Skip to main content

nemo_flow/codec/
openai_chat.rs

1// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Built-in codec for the OpenAI Chat Completions API.
5//!
6//! Implements [`LlmCodec`] (request decode/encode) and [`LlmResponseCodec`]
7//! (response decode) for the OpenAI Chat Completions format.
8
9use serde::Deserialize;
10
11use crate::api::llm::LlmRequest;
12use crate::error::{FlowError, Result};
13use crate::json::Json;
14
15use super::request::{AnnotatedLlmRequest, GenerationParams, Message, ToolChoice, ToolDefinition};
16use super::response::{
17    AnnotatedLlmResponse, ApiSpecificResponse, FinishReason, ResponseToolCall, Usage,
18};
19use super::traits::{LlmCodec, LlmResponseCodec};
20
21// ---------------------------------------------------------------------------
22// Public codec struct
23// ---------------------------------------------------------------------------
24
/// Built-in codec for the OpenAI Chat Completions API.
///
/// Stateless: it implements [`LlmCodec`] (request decode/encode) and
/// [`LlmResponseCodec`] (response decode), so a single value can be copied
/// and shared freely.
#[derive(Debug, Clone, Copy, Default)]
pub struct OpenAIChatCodec;
27
28// ---------------------------------------------------------------------------
29// Private intermediate serde structs for response decode
30// ---------------------------------------------------------------------------
31
32#[derive(Deserialize)]
33struct RawChatCompletion {
34    id: Option<String>,
35    model: Option<String>,
36    choices: Option<Vec<RawChoice>>,
37    usage: Option<RawChatUsage>,
38    system_fingerprint: Option<String>,
39    service_tier: Option<String>,
40    #[serde(flatten)]
41    extra: serde_json::Map<String, Json>,
42}
43
44#[derive(Deserialize)]
45struct RawChoice {
46    message: Option<RawMessage>,
47    finish_reason: Option<String>,
48    logprobs: Option<Json>,
49}
50
51#[derive(Deserialize)]
52struct RawMessage {
53    content: Option<String>,
54    tool_calls: Option<Vec<RawToolCall>>,
55}
56
57#[derive(Deserialize)]
58struct RawToolCall {
59    id: Option<String>,
60    function: Option<RawFunction>,
61}
62
63#[derive(Deserialize)]
64struct RawFunction {
65    name: Option<String>,
66    arguments: Option<String>,
67}
68
69#[derive(Deserialize)]
70struct RawChatUsage {
71    prompt_tokens: Option<u64>,
72    completion_tokens: Option<u64>,
73    total_tokens: Option<u64>,
74    prompt_tokens_details: Option<RawPromptTokensDetails>,
75}
76
77#[derive(Deserialize)]
78struct RawPromptTokensDetails {
79    cached_tokens: Option<u64>,
80}
81
82// ---------------------------------------------------------------------------
83// Helper functions
84// ---------------------------------------------------------------------------
85
86/// Map OpenAI Chat finish_reason string to normalized [`FinishReason`].
87fn map_chat_finish_reason(reason: &str) -> FinishReason {
88    match reason {
89        "stop" => FinishReason::Complete,
90        "length" => FinishReason::Length,
91        "tool_calls" | "function_call" => FinishReason::ToolUse,
92        "content_filter" => FinishReason::ContentFilter,
93        other => FinishReason::Unknown(other.to_string()),
94    }
95}
96
97/// Parse OpenAI tool call arguments from JSON string to [`Json`] value.
98///
99/// Falls back to [`Json::String`] if parsing fails (malformed model output).
100fn parse_arguments(arguments: &str) -> Json {
101    serde_json::from_str(arguments).unwrap_or_else(|_| Json::String(arguments.to_string()))
102}
103
/// Top-level request keys that [`AnnotatedLlmRequest`] models explicitly.
///
/// Request decoding leaves these out of `extra`; every other top-level key is
/// carried through `extra` unchanged. Only membership matters, not order.
const MODELED_REQUEST_KEYS: &[&str] = &[
    // Conversation and model selection.
    "messages",
    "model",
    // Generation parameters (both token-limit spellings feed `max_tokens`).
    "temperature",
    "top_p",
    "max_tokens",
    "max_completion_tokens",
    "stop",
    // Tool calling.
    "tools",
    "tool_choice",
];
116
117// ---------------------------------------------------------------------------
118// LlmResponseCodec implementation
119// ---------------------------------------------------------------------------
120
121impl LlmResponseCodec for OpenAIChatCodec {
122    fn decode_response(&self, response: &Json) -> Result<AnnotatedLlmResponse> {
123        let raw: RawChatCompletion = serde_json::from_value(response.clone())
124            .map_err(|e| FlowError::Internal(format!("OpenAI Chat response decode: {e}")))?;
125
126        // Extract first choice (if any).
127        let choice = raw.choices.as_ref().and_then(|c| c.first());
128
129        // Map message content.
130        let message = choice
131            .and_then(|c| c.message.as_ref())
132            .and_then(|m| m.content.as_ref())
133            .map(|s| super::request::MessageContent::Text(s.clone()));
134
135        // Map tool calls, skipping entries that lack a usable function body.
136        // Some providers (proxies, vLLM, NIM) may return partial tool_calls
137        // entries where `function` or `function.name` is absent or null.
138        let tool_calls = choice
139            .and_then(|c| c.message.as_ref())
140            .and_then(|m| m.tool_calls.as_ref())
141            .map(|tcs| {
142                tcs.iter()
143                    .filter_map(|tc| {
144                        let func = tc.function.as_ref()?;
145                        let name = func.name.as_ref()?;
146                        Some(ResponseToolCall {
147                            id: tc.id.clone().unwrap_or_default(),
148                            name: name.clone(),
149                            arguments: func
150                                .arguments
151                                .as_deref()
152                                .map(parse_arguments)
153                                .unwrap_or(Json::Object(Default::default())),
154                        })
155                    })
156                    .collect::<Vec<_>>()
157            });
158
159        // Map finish reason.
160        let finish_reason = choice
161            .and_then(|c| c.finish_reason.as_deref())
162            .map(map_chat_finish_reason);
163
164        // Map usage.
165        let usage = raw.usage.map(|u| Usage {
166            prompt_tokens: u.prompt_tokens,
167            completion_tokens: u.completion_tokens,
168            total_tokens: u.total_tokens,
169            cache_read_tokens: u.prompt_tokens_details.and_then(|d| d.cached_tokens),
170            cache_write_tokens: None,
171        });
172
173        // Build API-specific fields.
174        let logprobs = choice.and_then(|c| c.logprobs.clone());
175        let api_specific = Some(ApiSpecificResponse::OpenAIChat {
176            logprobs,
177            system_fingerprint: raw.system_fingerprint,
178            service_tier: raw.service_tier,
179        });
180
181        Ok(AnnotatedLlmResponse {
182            id: raw.id,
183            model: raw.model,
184            message,
185            tool_calls,
186            finish_reason,
187            usage,
188            api_specific,
189            extra: raw.extra,
190        })
191    }
192}
193
194// ---------------------------------------------------------------------------
195// LlmCodec implementation
196// ---------------------------------------------------------------------------
197
198impl LlmCodec for OpenAIChatCodec {
199    fn decode(&self, request: &LlmRequest) -> Result<AnnotatedLlmRequest> {
200        let obj = request
201            .content
202            .as_object()
203            .ok_or_else(|| FlowError::Internal("request content is not an object".into()))?;
204
205        // Extract messages (default to empty vec if absent).
206        let messages: Vec<Message> = obj
207            .get("messages")
208            .map(|v| serde_json::from_value(v.clone()).unwrap_or_default())
209            .unwrap_or_default();
210
211        // Extract model.
212        let model = obj.get("model").and_then(|v| v.as_str()).map(String::from);
213
214        // Extract generation params.
215        let temperature = obj.get("temperature").and_then(|v| v.as_f64());
216        let top_p = obj.get("top_p").and_then(|v| v.as_f64());
217        let stop = obj
218            .get("stop")
219            .and_then(|v| serde_json::from_value::<Vec<String>>(v.clone()).ok());
220
221        // max_completion_tokens takes priority over max_tokens (newer API key).
222        let max_tokens = obj
223            .get("max_completion_tokens")
224            .and_then(|v| v.as_u64())
225            .or_else(|| obj.get("max_tokens").and_then(|v| v.as_u64()));
226
227        let params =
228            if temperature.is_some() || max_tokens.is_some() || top_p.is_some() || stop.is_some() {
229                Some(GenerationParams {
230                    temperature,
231                    max_tokens,
232                    top_p,
233                    stop,
234                })
235            } else {
236                None
237            };
238
239        // Extract tools.
240        let tools: Option<Vec<ToolDefinition>> = obj
241            .get("tools")
242            .map(|v| serde_json::from_value(v.clone()))
243            .transpose()
244            .map_err(|e| FlowError::Internal(format!("OpenAI Chat tools decode: {e}")))?;
245
246        // Extract tool_choice.
247        let tool_choice: Option<ToolChoice> = obj
248            .get("tool_choice")
249            .map(|v| serde_json::from_value(v.clone()))
250            .transpose()
251            .map_err(|e| FlowError::Internal(format!("OpenAI Chat tool_choice decode: {e}")))?;
252
253        // Collect extra fields (keys not in MODELED_REQUEST_KEYS).
254        let extra: serde_json::Map<String, Json> = obj
255            .iter()
256            .filter(|(k, _)| !MODELED_REQUEST_KEYS.contains(&k.as_str()))
257            .map(|(k, v)| (k.clone(), v.clone()))
258            .collect();
259
260        Ok(AnnotatedLlmRequest {
261            messages,
262            model,
263            params,
264            tools,
265            tool_choice,
266            extra,
267        })
268    }
269
270    fn encode(&self, annotated: &AnnotatedLlmRequest, original: &LlmRequest) -> Result<LlmRequest> {
271        let mut content = original.content.clone();
272        let obj = content
273            .as_object_mut()
274            .ok_or_else(|| FlowError::Internal("original content is not an object".into()))?;
275
276        insert_serialized(obj, "messages", &annotated.messages, "messages")?;
277
278        if let Some(ref model) = annotated.model {
279            obj.insert("model".into(), Json::String(model.clone()));
280        }
281
282        if let Some(ref params) = annotated.params {
283            overlay_generation_params(obj, params)?;
284        }
285
286        if let Some(ref tools) = annotated.tools {
287            insert_serialized(obj, "tools", tools, "tools")?;
288        }
289
290        if let Some(ref tool_choice) = annotated.tool_choice {
291            insert_serialized(obj, "tool_choice", tool_choice, "tool_choice")?;
292        }
293
294        for (k, v) in &annotated.extra {
295            obj.insert(k.clone(), v.clone());
296        }
297
298        // Force `stream_options.include_usage` when the caller did not set it.
299        //
300        // Rationale: OpenAI-compatible backends only emit the terminal chunk
301        // containing `usage` (prompt/completion/total tokens) when this flag
302        // is true. Without it, Phoenix spans show `token_count=0` for every
303        // LLM call even though the provider knows the real counts. The
304        // observability exporter (OpenInference) reads usage off the
305        // annotated response, so the flag has to be set at the request level
306        // before bytes go on the wire.
307        //
308        // Guarded on `stream == true` per the OpenAI Chat Completions spec,
309        // which restricts `stream_options` to streaming requests. Caller-
310        // provided `stream_options` are preserved verbatim (including
311        // explicit opt-outs such as `include_usage: false`).
312        let is_streaming = obj.get("stream").and_then(|v| v.as_bool()).unwrap_or(false);
313        if is_streaming && !obj.contains_key("stream_options") {
314            obj.insert(
315                "stream_options".into(),
316                serde_json::json!({"include_usage": true}),
317            );
318        }
319
320        Ok(LlmRequest {
321            headers: original.headers.clone(),
322            content,
323        })
324    }
325}
326
327/// Helper to construct a [`Json`] number from an `f64`.
328fn json_f64(v: f64) -> Json {
329    serde_json::Number::from_f64(v)
330        .map(Json::Number)
331        .unwrap_or(Json::Null)
332}
333
334fn insert_serialized<T: serde::Serialize>(
335    obj: &mut serde_json::Map<String, Json>,
336    key: &str,
337    value: &T,
338    context: &str,
339) -> Result<()> {
340    let json = serde_json::to_value(value)
341        .map_err(|e| FlowError::Internal(format!("OpenAI Chat {context} encode: {e}")))?;
342    obj.insert(key.into(), json);
343    Ok(())
344}
345
346fn overlay_generation_params(
347    obj: &mut serde_json::Map<String, Json>,
348    params: &GenerationParams,
349) -> Result<()> {
350    if let Some(temp) = params.temperature {
351        obj.insert("temperature".into(), json_f64(temp));
352    }
353    if let Some(top_p) = params.top_p {
354        obj.insert("top_p".into(), json_f64(top_p));
355    }
356    if let Some(ref stop) = params.stop {
357        insert_serialized(obj, "stop", stop, "stop")?;
358    }
359    if let Some(max_tokens) = params.max_tokens {
360        let key = if obj.contains_key("max_completion_tokens") {
361            "max_completion_tokens"
362        } else {
363            "max_tokens"
364        };
365        obj.insert(key.into(), Json::from(max_tokens));
366    }
367    Ok(())
368}
369
370// ---------------------------------------------------------------------------
371// Tests
372// ---------------------------------------------------------------------------
373
374#[cfg(test)]
375#[path = "../../tests/unit/codec/openai_chat_tests.rs"]
376mod tests;