1use serde::Deserialize;
10
11use crate::api::llm::LlmRequest;
12use crate::error::{FlowError, Result};
13use crate::json::Json;
14
15use super::request::{AnnotatedLlmRequest, GenerationParams, Message, ToolChoice, ToolDefinition};
16use super::response::{
17 AnnotatedLlmResponse, ApiSpecificResponse, FinishReason, ResponseToolCall, Usage,
18};
19use super::traits::{LlmCodec, LlmResponseCodec};
20
/// Stateless codec for OpenAI Chat Completions request and response bodies.
///
/// Implements both `LlmCodec` (request decode/encode) and `LlmResponseCodec`
/// (response decode). It carries no data, so it is freely copyable.
#[derive(Debug, Clone, Copy, Default)]
pub struct OpenAIChatCodec;
27
/// Subset of an OpenAI Chat Completions response body that the response decoder reads.
///
/// Every modeled field is an `Option`, so a missing key decodes as `None` instead of
/// failing. Any top-level key not listed here is collected into `extra`.
#[derive(Deserialize)]
struct RawChatCompletion {
    id: Option<String>,
    model: Option<String>,
    /// Candidate completions; the decoder only reads the first entry.
    choices: Option<Vec<RawChoice>>,
    usage: Option<RawChatUsage>,
    system_fingerprint: Option<String>,
    service_tier: Option<String>,
    /// All remaining top-level keys, forwarded unchanged to the annotated response.
    #[serde(flatten)]
    extra: serde_json::Map<String, Json>,
}
43
/// One entry of the response's `choices` array.
#[derive(Deserialize)]
struct RawChoice {
    /// Message payload for this choice (text content and/or tool calls).
    message: Option<RawMessage>,
    /// Raw `finish_reason` string; translated by `map_chat_finish_reason`.
    finish_reason: Option<String>,
    /// Log-probability payload, kept as opaque JSON.
    logprobs: Option<Json>,
}
50
/// The `message` object of a choice. Keys other than the two below are ignored.
#[derive(Deserialize)]
struct RawMessage {
    /// Text content; `None` when the key is absent or `null`.
    content: Option<String>,
    tool_calls: Option<Vec<RawToolCall>>,
}
56
/// One element of `message.tool_calls`.
#[derive(Deserialize)]
struct RawToolCall {
    /// Call identifier; the decoder substitutes an empty string when absent.
    id: Option<String>,
    /// Function name plus JSON-encoded arguments; calls without it are skipped.
    function: Option<RawFunction>,
}
62
/// The `function` object inside a tool call.
#[derive(Deserialize)]
struct RawFunction {
    name: Option<String>,
    /// Arguments as a JSON-encoded *string*; parsed by `parse_arguments`.
    arguments: Option<String>,
}
68
/// Token accounting from the response's `usage` object.
#[derive(Deserialize)]
struct RawChatUsage {
    prompt_tokens: Option<u64>,
    completion_tokens: Option<u64>,
    total_tokens: Option<u64>,
    /// Prompt-token breakdown; only `cached_tokens` is read from it.
    prompt_tokens_details: Option<RawPromptTokensDetails>,
}
76
/// `usage.prompt_tokens_details`; the decoder reports `cached_tokens` as cache-read tokens.
#[derive(Deserialize)]
struct RawPromptTokensDetails {
    cached_tokens: Option<u64>,
}
81
82fn map_chat_finish_reason(reason: &str) -> FinishReason {
88 match reason {
89 "stop" => FinishReason::Complete,
90 "length" => FinishReason::Length,
91 "tool_calls" | "function_call" => FinishReason::ToolUse,
92 "content_filter" => FinishReason::ContentFilter,
93 other => FinishReason::Unknown(other.to_string()),
94 }
95}
96
97fn parse_arguments(arguments: &str) -> Json {
101 serde_json::from_str(arguments).unwrap_or_else(|_| Json::String(arguments.to_string()))
102}
103
/// Top-level request keys that `decode` lifts into typed fields of `AnnotatedLlmRequest`.
///
/// Every key *not* in this list is copied verbatim into `AnnotatedLlmRequest::extra`.
const MODELED_REQUEST_KEYS: &[&str] = &[
    // Conversation and model selection.
    "messages",
    "model",
    // Generation parameters (both token-limit spellings are recognized).
    "temperature",
    "max_tokens",
    "max_completion_tokens",
    "top_p",
    "stop",
    // Tool calling.
    "tools",
    "tool_choice",
];
116
117impl LlmResponseCodec for OpenAIChatCodec {
122 fn decode_response(&self, response: &Json) -> Result<AnnotatedLlmResponse> {
123 let raw: RawChatCompletion = serde_json::from_value(response.clone())
124 .map_err(|e| FlowError::Internal(format!("OpenAI Chat response decode: {e}")))?;
125
126 let choice = raw.choices.as_ref().and_then(|c| c.first());
128
129 let message = choice
131 .and_then(|c| c.message.as_ref())
132 .and_then(|m| m.content.as_ref())
133 .map(|s| super::request::MessageContent::Text(s.clone()));
134
135 let tool_calls = choice
139 .and_then(|c| c.message.as_ref())
140 .and_then(|m| m.tool_calls.as_ref())
141 .map(|tcs| {
142 tcs.iter()
143 .filter_map(|tc| {
144 let func = tc.function.as_ref()?;
145 let name = func.name.as_ref()?;
146 Some(ResponseToolCall {
147 id: tc.id.clone().unwrap_or_default(),
148 name: name.clone(),
149 arguments: func
150 .arguments
151 .as_deref()
152 .map(parse_arguments)
153 .unwrap_or(Json::Object(Default::default())),
154 })
155 })
156 .collect::<Vec<_>>()
157 });
158
159 let finish_reason = choice
161 .and_then(|c| c.finish_reason.as_deref())
162 .map(map_chat_finish_reason);
163
164 let usage = raw.usage.map(|u| Usage {
166 prompt_tokens: u.prompt_tokens,
167 completion_tokens: u.completion_tokens,
168 total_tokens: u.total_tokens,
169 cache_read_tokens: u.prompt_tokens_details.and_then(|d| d.cached_tokens),
170 cache_write_tokens: None,
171 });
172
173 let logprobs = choice.and_then(|c| c.logprobs.clone());
175 let api_specific = Some(ApiSpecificResponse::OpenAIChat {
176 logprobs,
177 system_fingerprint: raw.system_fingerprint,
178 service_tier: raw.service_tier,
179 });
180
181 Ok(AnnotatedLlmResponse {
182 id: raw.id,
183 model: raw.model,
184 message,
185 tool_calls,
186 finish_reason,
187 usage,
188 api_specific,
189 extra: raw.extra,
190 })
191 }
192}
193
194impl LlmCodec for OpenAIChatCodec {
199 fn decode(&self, request: &LlmRequest) -> Result<AnnotatedLlmRequest> {
200 let obj = request
201 .content
202 .as_object()
203 .ok_or_else(|| FlowError::Internal("request content is not an object".into()))?;
204
205 let messages: Vec<Message> = obj
207 .get("messages")
208 .map(|v| serde_json::from_value(v.clone()).unwrap_or_default())
209 .unwrap_or_default();
210
211 let model = obj.get("model").and_then(|v| v.as_str()).map(String::from);
213
214 let temperature = obj.get("temperature").and_then(|v| v.as_f64());
216 let top_p = obj.get("top_p").and_then(|v| v.as_f64());
217 let stop = obj
218 .get("stop")
219 .and_then(|v| serde_json::from_value::<Vec<String>>(v.clone()).ok());
220
221 let max_tokens = obj
223 .get("max_completion_tokens")
224 .and_then(|v| v.as_u64())
225 .or_else(|| obj.get("max_tokens").and_then(|v| v.as_u64()));
226
227 let params =
228 if temperature.is_some() || max_tokens.is_some() || top_p.is_some() || stop.is_some() {
229 Some(GenerationParams {
230 temperature,
231 max_tokens,
232 top_p,
233 stop,
234 })
235 } else {
236 None
237 };
238
239 let tools: Option<Vec<ToolDefinition>> = obj
241 .get("tools")
242 .map(|v| serde_json::from_value(v.clone()))
243 .transpose()
244 .map_err(|e| FlowError::Internal(format!("OpenAI Chat tools decode: {e}")))?;
245
246 let tool_choice: Option<ToolChoice> = obj
248 .get("tool_choice")
249 .map(|v| serde_json::from_value(v.clone()))
250 .transpose()
251 .map_err(|e| FlowError::Internal(format!("OpenAI Chat tool_choice decode: {e}")))?;
252
253 let extra: serde_json::Map<String, Json> = obj
255 .iter()
256 .filter(|(k, _)| !MODELED_REQUEST_KEYS.contains(&k.as_str()))
257 .map(|(k, v)| (k.clone(), v.clone()))
258 .collect();
259
260 Ok(AnnotatedLlmRequest {
261 messages,
262 model,
263 params,
264 tools,
265 tool_choice,
266 extra,
267 })
268 }
269
270 fn encode(&self, annotated: &AnnotatedLlmRequest, original: &LlmRequest) -> Result<LlmRequest> {
271 let mut content = original.content.clone();
272 let obj = content
273 .as_object_mut()
274 .ok_or_else(|| FlowError::Internal("original content is not an object".into()))?;
275
276 insert_serialized(obj, "messages", &annotated.messages, "messages")?;
277
278 if let Some(ref model) = annotated.model {
279 obj.insert("model".into(), Json::String(model.clone()));
280 }
281
282 if let Some(ref params) = annotated.params {
283 overlay_generation_params(obj, params)?;
284 }
285
286 if let Some(ref tools) = annotated.tools {
287 insert_serialized(obj, "tools", tools, "tools")?;
288 }
289
290 if let Some(ref tool_choice) = annotated.tool_choice {
291 insert_serialized(obj, "tool_choice", tool_choice, "tool_choice")?;
292 }
293
294 for (k, v) in &annotated.extra {
295 obj.insert(k.clone(), v.clone());
296 }
297
298 let is_streaming = obj.get("stream").and_then(|v| v.as_bool()).unwrap_or(false);
313 if is_streaming && !obj.contains_key("stream_options") {
314 obj.insert(
315 "stream_options".into(),
316 serde_json::json!({"include_usage": true}),
317 );
318 }
319
320 Ok(LlmRequest {
321 headers: original.headers.clone(),
322 content,
323 })
324 }
325}
326
327fn json_f64(v: f64) -> Json {
329 serde_json::Number::from_f64(v)
330 .map(Json::Number)
331 .unwrap_or(Json::Null)
332}
333
334fn insert_serialized<T: serde::Serialize>(
335 obj: &mut serde_json::Map<String, Json>,
336 key: &str,
337 value: &T,
338 context: &str,
339) -> Result<()> {
340 let json = serde_json::to_value(value)
341 .map_err(|e| FlowError::Internal(format!("OpenAI Chat {context} encode: {e}")))?;
342 obj.insert(key.into(), json);
343 Ok(())
344}
345
346fn overlay_generation_params(
347 obj: &mut serde_json::Map<String, Json>,
348 params: &GenerationParams,
349) -> Result<()> {
350 if let Some(temp) = params.temperature {
351 obj.insert("temperature".into(), json_f64(temp));
352 }
353 if let Some(top_p) = params.top_p {
354 obj.insert("top_p".into(), json_f64(top_p));
355 }
356 if let Some(ref stop) = params.stop {
357 insert_serialized(obj, "stop", stop, "stop")?;
358 }
359 if let Some(max_tokens) = params.max_tokens {
360 let key = if obj.contains_key("max_completion_tokens") {
361 "max_completion_tokens"
362 } else {
363 "max_tokens"
364 };
365 obj.insert(key.into(), Json::from(max_tokens));
366 }
367 Ok(())
368}
369
370#[cfg(test)]
375#[path = "../../tests/unit/codec/openai_chat_tests.rs"]
376mod tests;