Skip to main content

nemo_flow/codec/
request.rs

1// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
//! LLM request codec types.
//!
//! This module defines the [`AnnotatedLlmRequest`] type system for structured
//! LLM request representation. These types are the typed form that the
//! [`crate::codec::traits::LlmCodec`] trait translates to and from when
//! converting opaque [`crate::api::llm::LlmRequest`] payloads in either
//! direction.
10
11use serde::{Deserialize, Serialize};
12
13use crate::json::Json;
14
15// ---------------------------------------------------------------------------
16// AnnotatedLlmRequest type hierarchy
17// ---------------------------------------------------------------------------
18
/// Structured view of an LLM request, produced by a Codec from opaque
/// [`LlmRequest`](crate::api::llm::LlmRequest) content.
///
/// When deserializing, `messages` is the only field without a fallback: every
/// other modeled field is an `Option` (absent keys become `None`), and each
/// `None` is omitted again on serialization via `skip_serializing_if`.
///
/// The `extra` field captures any provider-specific keys not modeled by the
/// known fields, ensuring lossless round-trip through `decode`/`encode`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AnnotatedLlmRequest {
    /// Parsed conversation messages.
    pub messages: Vec<Message>,
    /// Model identifier (e.g., `"gpt-4"`, `"claude-sonnet-4-20250514"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Common generation parameters, normalized.
    ///
    /// Serialized as a nested `params` object; it is not flattened into the
    /// top-level request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<GenerationParams>,
    /// Tool definitions (function schemas) available to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ToolDefinition>>,
    /// Tool choice control.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// OpenAI Responses: whether to persist response state server-side.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub store: Option<bool>,
    /// OpenAI Responses: prior response to continue from.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub previous_response_id: Option<String>,
    /// OpenAI Responses: context truncation behavior.
    ///
    /// Kept as untyped JSON; this module does not validate its shape.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub truncation: Option<Json>,
    /// OpenAI Responses: reasoning configuration object.
    ///
    /// Kept as untyped JSON; this module does not validate its shape.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<Json>,
    /// OpenAI Responses: include filter for additional output/state items.
    ///
    /// Kept as untyped JSON; this module does not validate its shape.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include: Option<Json>,
    /// OpenAI user identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// OpenAI metadata map/object.
    ///
    /// Kept as untyped JSON; this module does not validate its shape.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Json>,
    /// OpenAI service tier preference.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<String>,
    /// OpenAI tool parallelism toggle.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// OpenAI Responses max output token limit.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u64>,
    /// OpenAI Responses max tool calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tool_calls: Option<u64>,
    /// OpenAI logprob fanout count.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u64>,
    /// OpenAI streaming toggle.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Extensible key-value pairs for unmodeled provider-specific fields.
    /// Merged back into the request body during encode via `serde(flatten)`.
    ///
    /// Because of the flatten attribute, these entries appear as top-level
    /// keys alongside the modeled fields rather than under an `extra` key.
    #[serde(flatten)]
    pub extra: serde_json::Map<String, Json>,
}
84
/// A single message in a conversation, tagged by role.
///
/// Serialized as an internally tagged object: the variant is carried in a
/// `role` key whose value is the lowercased variant name (`"system"`,
/// `"user"`, `"assistant"`, or `"tool"`), next to the variant's own fields.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// A system instruction message.
    System {
        /// The message content.
        content: MessageContent,
        /// Optional sender name; omitted from output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// A user message.
    User {
        /// The message content.
        content: MessageContent,
        /// Optional sender name; omitted from output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// An assistant response, optionally containing tool calls.
    Assistant {
        /// The message content (optional — may be absent when tool calls are present).
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<MessageContent>,
        /// Tool calls requested by the assistant; omitted from output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<ToolCall>>,
        /// Optional sender name; omitted from output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// A tool result message.
    ///
    /// Unlike the other variants, this one has no `name` field.
    Tool {
        /// The tool execution result.
        content: MessageContent,
        /// The ID of the tool call this result corresponds to.
        tool_call_id: String,
    },
}
125
/// Message content: either a plain string or multimodal parts array.
///
/// The enum is untagged, so no discriminator is written: `Text` serializes as
/// a bare JSON string and `Parts` as a JSON array. On deserialization the
/// variants are tried in declaration order.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain text content.
    Text(String),
    /// Multimodal content parts.
    Parts(Vec<ContentPart>),
}
135
/// A single content part within a multimodal message.
///
/// Text and image-URL parts are modeled. Each part is an internally tagged
/// object whose `type` key holds the snake_case variant name (`"text"` or
/// `"image_url"`). Other part kinds (for example audio) are not represented
/// by this enum.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// A text content part.
    Text {
        /// The text content.
        text: String,
    },
    /// An image URL content part.
    ImageUrl {
        /// Image URL payload.
        image_url: OpenAiImageUrl,
    },
}
153
/// OpenAI image URL payload.
///
/// Carried by [`ContentPart::ImageUrl`] as its `image_url` field.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OpenAiImageUrl {
    /// URL for the image.
    pub url: String,
    /// Optional provider-specific detail hint; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
163
/// A tool call requested by the assistant.
///
/// Appears in the `tool_calls` list of [`Message::Assistant`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Unique identifier for this tool call.
    pub id: String,
    /// The type of tool call (typically `"function"`).
    ///
    /// Named `call_type` in Rust because `type` is a keyword; the serialized
    /// key is `type`.
    #[serde(rename = "type")]
    pub call_type: String,
    /// The function to call.
    pub function: FunctionCall,
}
175
/// A function call within a tool call.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionCall {
    /// The name of the function to call.
    pub name: String,
    /// The function arguments as a JSON string (per OpenAI convention).
    ///
    /// Stored as a plain `String`; this type does not parse or validate it.
    pub arguments: String,
}
184
/// A tool definition (function schema) available to the model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// The type of tool (typically `"function"`).
    ///
    /// Named `tool_type` in Rust because `type` is a keyword; the serialized
    /// key is `type`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function definition.
    pub function: FunctionDefinition,
}
194
/// A function definition within a tool definition.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// The name of the function.
    pub name: String,
    /// A description of what the function does; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// The JSON Schema for the function parameters.
    ///
    /// Kept as untyped JSON and omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Json>,
}
207
/// Tool choice control: how the model should use available tools.
///
/// The unit variants use their lowercased names on the wire (`"auto"`,
/// `"none"`, `"required"`). `Specific` is marked untagged, so it is
/// represented directly by the fields of its [`ToolChoiceFunction`] payload
/// with no variant name wrapped around it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    /// Let the model decide whether to call a tool.
    Auto,
    /// Do not call any tools.
    None,
    /// The model must call at least one tool.
    Required,
    /// Force a specific function by name.
    #[serde(untagged)]
    Specific(ToolChoiceFunction),
}
222
/// A specific tool choice that forces a named function.
///
/// This is the payload of [`ToolChoice::Specific`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
    /// The type (typically `"function"`).
    ///
    /// Named `choice_type` in Rust because `type` is a keyword; the
    /// serialized key is `type`.
    #[serde(rename = "type")]
    pub choice_type: String,
    /// The function to call.
    pub function: ToolChoiceFunctionName,
}
232
/// The name component of a specific tool choice.
///
/// Serialized as an object with a single `name` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolChoiceFunctionName {
    /// The function name.
    pub name: String,
}
239
/// Normalized generation parameters across providers.
///
/// Every field is optional and omitted from serialized output when `None`;
/// the derived `Default` yields a value with all fields unset.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct GenerationParams {
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Maximum number of tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    /// Nucleus sampling probability.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    /// Stop sequences.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
}
256
257// ---------------------------------------------------------------------------
258// Helper methods
259// ---------------------------------------------------------------------------
260
261impl AnnotatedLlmRequest {
262    /// Extract the text content of the first system message, if any.
263    ///
264    /// For [`MessageContent::Text`], returns the string directly.
265    /// For [`MessageContent::Parts`], returns the text of the first
266    /// [`ContentPart::Text`] part.
267    pub fn system_prompt(&self) -> Option<&str> {
268        self.messages.iter().find_map(|m| match m {
269            Message::System { content, .. } => match content {
270                MessageContent::Text(s) => Some(s.as_str()),
271                MessageContent::Parts(parts) => parts.iter().find_map(|p| match p {
272                    ContentPart::Text { text } => Some(text.as_str()),
273                    ContentPart::ImageUrl { .. } => None,
274                }),
275            },
276            _ => None,
277        })
278    }
279
280    /// Get the text content of the last user message, if any.
281    ///
282    /// Searches messages in reverse order and returns the first user
283    /// message found. For [`MessageContent::Parts`], returns the text of
284    /// the first [`ContentPart::Text`] part.
285    pub fn last_user_message(&self) -> Option<&str> {
286        self.messages.iter().rev().find_map(|m| match m {
287            Message::User { content, .. } => match content {
288                MessageContent::Text(s) => Some(s.as_str()),
289                MessageContent::Parts(parts) => parts.iter().find_map(|p| match p {
290                    ContentPart::Text { text } => Some(text.as_str()),
291                    ContentPart::ImageUrl { .. } => None,
292                }),
293            },
294            _ => None,
295        })
296    }
297
298    /// Check if any assistant message in the conversation contains tool calls.
299    ///
300    /// Returns `true` if at least one [`Message::Assistant`] variant has a
301    /// non-empty `tool_calls` field.
302    pub fn has_tool_calls(&self) -> bool {
303        self.messages.iter().any(|m| {
304            matches!(
305                m,
306                Message::Assistant { tool_calls: Some(calls), .. } if !calls.is_empty()
307            )
308        })
309    }
310}
311
// Unit tests for this module are kept in a separate file, located through the
// `#[path]` attribute and compiled only in test builds.
#[cfg(test)]
#[path = "../../tests/unit/codec/request_tests.rs"]
mod tests;