Skip to main content

nemo_flow/codec/
request.rs

// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! LLM request codec types and trait.
//!
//! This module defines the [`AnnotatedLlmRequest`] type system for structured
//! LLM request representation and the [`crate::codec::traits::LlmCodec`] trait
//! for bidirectional translation between opaque [`crate::api::llm::LlmRequest`]
//! payloads and typed form.
10
11use serde::{Deserialize, Serialize};
12
13use crate::json::Json;
14
// ---------------------------------------------------------------------------
// AnnotatedLlmRequest type hierarchy
// ---------------------------------------------------------------------------
18
19/// Structured view of an LLM request, produced by a Codec from opaque
20/// [`LlmRequest`](crate::api::llm::LlmRequest) content.
21///
22/// The `extra` field captures any provider-specific keys not modeled by the
23/// known fields, ensuring lossless round-trip through `decode`/`encode`.
24#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
25pub struct AnnotatedLlmRequest {
26    /// Parsed conversation messages.
27    pub messages: Vec<Message>,
28    /// Model identifier (e.g., `"gpt-4"`, `"claude-sonnet-4-20250514"`).
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub model: Option<String>,
31    /// Common generation parameters, normalized.
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub params: Option<GenerationParams>,
34    /// Tool definitions (function schemas) available to the model.
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub tools: Option<Vec<ToolDefinition>>,
37    /// Tool choice control.
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub tool_choice: Option<ToolChoice>,
40    /// Extensible key-value pairs for unmodeled provider-specific fields.
41    /// Merged back into the request body during encode via `serde(flatten)`.
42    #[serde(flatten)]
43    pub extra: serde_json::Map<String, Json>,
44}
45
46/// A single message in a conversation, tagged by role.
47#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
48#[serde(tag = "role", rename_all = "lowercase")]
49pub enum Message {
50    /// A system instruction message.
51    System {
52        /// The message content.
53        content: MessageContent,
54        /// Optional sender name.
55        #[serde(skip_serializing_if = "Option::is_none")]
56        name: Option<String>,
57    },
58    /// A user message.
59    User {
60        /// The message content.
61        content: MessageContent,
62        /// Optional sender name.
63        #[serde(skip_serializing_if = "Option::is_none")]
64        name: Option<String>,
65    },
66    /// An assistant response, optionally containing tool calls.
67    Assistant {
68        /// The message content (optional — may be absent when tool calls are present).
69        #[serde(skip_serializing_if = "Option::is_none")]
70        content: Option<MessageContent>,
71        /// Tool calls requested by the assistant.
72        #[serde(skip_serializing_if = "Option::is_none")]
73        tool_calls: Option<Vec<ToolCall>>,
74        /// Optional sender name.
75        #[serde(skip_serializing_if = "Option::is_none")]
76        name: Option<String>,
77    },
78    /// A tool result message.
79    Tool {
80        /// The tool execution result.
81        content: MessageContent,
82        /// The ID of the tool call this result corresponds to.
83        tool_call_id: String,
84    },
85}
86
87/// Message content: either a plain string or multimodal parts array.
88#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
89#[serde(untagged)]
90pub enum MessageContent {
91    /// Plain text content.
92    Text(String),
93    /// Multimodal content parts.
94    Parts(Vec<ContentPart>),
95}
96
97/// A single content part within a multimodal message.
98///
99/// v1 supports text only. Future versions may add `ImageUrl`, `Audio`, etc.
100#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
101#[serde(tag = "type", rename_all = "snake_case")]
102pub enum ContentPart {
103    /// A text content part.
104    Text {
105        /// The text content.
106        text: String,
107    },
108}
109
110/// A tool call requested by the assistant.
111#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
112pub struct ToolCall {
113    /// Unique identifier for this tool call.
114    pub id: String,
115    /// The type of tool call (typically `"function"`).
116    #[serde(rename = "type")]
117    pub call_type: String,
118    /// The function to call.
119    pub function: FunctionCall,
120}
121
122/// A function call within a tool call.
123#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
124pub struct FunctionCall {
125    /// The name of the function to call.
126    pub name: String,
127    /// The function arguments as a JSON string (per OpenAI convention).
128    pub arguments: String,
129}
130
131/// A tool definition (function schema) available to the model.
132#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
133pub struct ToolDefinition {
134    /// The type of tool (typically `"function"`).
135    #[serde(rename = "type")]
136    pub tool_type: String,
137    /// The function definition.
138    pub function: FunctionDefinition,
139}
140
141/// A function definition within a tool definition.
142#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
143pub struct FunctionDefinition {
144    /// The name of the function.
145    pub name: String,
146    /// A description of what the function does.
147    #[serde(skip_serializing_if = "Option::is_none")]
148    pub description: Option<String>,
149    /// The JSON Schema for the function parameters.
150    #[serde(skip_serializing_if = "Option::is_none")]
151    pub parameters: Option<Json>,
152}
153
154/// Tool choice control: how the model should use available tools.
155#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
156#[serde(rename_all = "lowercase")]
157pub enum ToolChoice {
158    /// Let the model decide whether to call a tool.
159    Auto,
160    /// Do not call any tools.
161    None,
162    /// The model must call at least one tool.
163    Required,
164    /// Force a specific function by name.
165    #[serde(untagged)]
166    Specific(ToolChoiceFunction),
167}
168
169/// A specific tool choice that forces a named function.
170#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
171pub struct ToolChoiceFunction {
172    /// The type (typically `"function"`).
173    #[serde(rename = "type")]
174    pub choice_type: String,
175    /// The function to call.
176    pub function: ToolChoiceFunctionName,
177}
178
179/// The name component of a specific tool choice.
180#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
181pub struct ToolChoiceFunctionName {
182    /// The function name.
183    pub name: String,
184}
185
186/// Normalized generation parameters across providers.
187#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
188pub struct GenerationParams {
189    /// Sampling temperature.
190    #[serde(skip_serializing_if = "Option::is_none")]
191    pub temperature: Option<f64>,
192    /// Maximum number of tokens to generate.
193    #[serde(skip_serializing_if = "Option::is_none")]
194    pub max_tokens: Option<u64>,
195    /// Nucleus sampling probability.
196    #[serde(skip_serializing_if = "Option::is_none")]
197    pub top_p: Option<f64>,
198    /// Stop sequences.
199    #[serde(skip_serializing_if = "Option::is_none")]
200    pub stop: Option<Vec<String>>,
201}
202
// ---------------------------------------------------------------------------
// Helper methods
// ---------------------------------------------------------------------------
206
207impl AnnotatedLlmRequest {
208    /// Extract the text content of the first system message, if any.
209    ///
210    /// For [`MessageContent::Text`], returns the string directly.
211    /// For [`MessageContent::Parts`], returns the text of the first
212    /// [`ContentPart::Text`] part.
213    pub fn system_prompt(&self) -> Option<&str> {
214        self.messages.iter().find_map(|m| match m {
215            Message::System { content, .. } => match content {
216                MessageContent::Text(s) => Some(s.as_str()),
217                MessageContent::Parts(parts) => parts
218                    .iter()
219                    .map(|p| {
220                        let ContentPart::Text { text } = p;
221                        text.as_str()
222                    })
223                    .next(),
224            },
225            _ => None,
226        })
227    }
228
229    /// Get the text content of the last user message, if any.
230    ///
231    /// Searches messages in reverse order and returns the first user
232    /// message found. For [`MessageContent::Parts`], returns the text of
233    /// the first [`ContentPart::Text`] part.
234    pub fn last_user_message(&self) -> Option<&str> {
235        self.messages.iter().rev().find_map(|m| match m {
236            Message::User { content, .. } => match content {
237                MessageContent::Text(s) => Some(s.as_str()),
238                MessageContent::Parts(parts) => parts
239                    .iter()
240                    .map(|p| {
241                        let ContentPart::Text { text } = p;
242                        text.as_str()
243                    })
244                    .next(),
245            },
246            _ => None,
247        })
248    }
249
250    /// Check if any assistant message in the conversation contains tool calls.
251    ///
252    /// Returns `true` if at least one [`Message::Assistant`] variant has a
253    /// non-empty `tool_calls` field.
254    pub fn has_tool_calls(&self) -> bool {
255        self.messages.iter().any(|m| {
256            matches!(
257                m,
258                Message::Assistant { tool_calls: Some(calls), .. } if !calls.is_empty()
259            )
260        })
261    }
262}
263
// Unit tests for this module live in a separate file, located through the
// explicit `#[path]` below, and are compiled only in test builds.
#[cfg(test)]
#[path = "../../tests/unit/codec/request_tests.rs"]
mod tests;