1use serde::Deserialize;
10
11use crate::api::llm::LlmRequest;
12use crate::error::{FlowError, Result};
13use crate::json::Json;
14
15use super::request::{AnnotatedLlmRequest, GenerationParams, Message, ToolChoice, ToolDefinition};
16use super::response::{
17 AnnotatedLlmResponse, ApiSpecificResponse, FinishReason, ResponseToolCall, Usage,
18};
19use super::traits::{LlmCodec, LlmResponseCodec};
20
/// Codec for the OpenAI Chat Completions wire format.
///
/// Decodes and re-encodes request bodies (`LlmCodec`) and decodes non-streaming response
/// bodies (`LlmResponseCodec`). Streaming responses are handled by
/// `OpenAIChatStreamingCodec`.
pub struct OpenAIChatCodec;
27
/// The parts of a non-streaming Chat Completions response body that `decode_response` reads.
///
/// Every modeled field is an `Option`, so a body that omits any of them still deserializes.
/// All remaining top-level keys are collected into `extra`.
#[derive(Deserialize)]
struct RawChatCompletion {
    id: Option<String>,
    model: Option<String>,
    choices: Option<Vec<RawChoice>>,
    usage: Option<RawChatUsage>,
    system_fingerprint: Option<String>,
    service_tier: Option<String>,
    // Every top-level key not named above, preserved verbatim for the annotated response.
    #[serde(flatten)]
    extra: serde_json::Map<String, Json>,
}
43
/// One element of a response's `choices` array.
#[derive(Deserialize)]
struct RawChoice {
    message: Option<RawMessage>,
    finish_reason: Option<String>,
    // Kept as untyped JSON; it is passed through to the API-specific response section.
    logprobs: Option<Json>,
}
50
/// The assistant `message` inside a choice: plain-text content and/or tool calls.
#[derive(Deserialize)]
struct RawMessage {
    // Only string content is modeled; presumably a non-string `content` would fail
    // deserialization of the whole response — NOTE(review): confirm providers never send
    // content-part arrays here.
    content: Option<String>,
    tool_calls: Option<Vec<RawToolCall>>,
}
56
/// One entry of `message.tool_calls`.
#[derive(Deserialize)]
struct RawToolCall {
    id: Option<String>,
    function: Option<RawFunction>,
}
62
/// The `function` object of a tool call.
#[derive(Deserialize)]
struct RawFunction {
    name: Option<String>,
    // Arguments arrive as a string; `parse_arguments` turns it into structured JSON.
    arguments: Option<String>,
}
68
/// The response-level `usage` object (token counts).
#[derive(Deserialize)]
struct RawChatUsage {
    prompt_tokens: Option<u64>,
    completion_tokens: Option<u64>,
    total_tokens: Option<u64>,
    prompt_tokens_details: Option<RawPromptTokensDetails>,
}
76
/// The `usage.prompt_tokens_details` object.
#[derive(Deserialize)]
struct RawPromptTokensDetails {
    // Reported by `decode_response` as `Usage::cache_read_tokens`.
    cached_tokens: Option<u64>,
}
81
82fn map_chat_finish_reason(reason: &str) -> FinishReason {
88 match reason {
89 "stop" => FinishReason::Complete,
90 "length" => FinishReason::Length,
91 "tool_calls" | "function_call" => FinishReason::ToolUse,
92 "content_filter" => FinishReason::ContentFilter,
93 other => FinishReason::Unknown(other.to_string()),
94 }
95}
96
97fn parse_arguments(arguments: &str) -> Json {
101 serde_json::from_str(arguments).unwrap_or_else(|_| Json::String(arguments.to_string()))
102}
103
/// Request keys that `decode` lifts into dedicated `AnnotatedLlmRequest` fields.
///
/// Every key *not* listed here is copied into `AnnotatedLlmRequest::extra` by `decode`.
/// Keep this list in sync with the fields `decode` reads. Otherwise a key would either be
/// invisible in the annotated form (listed but never read) or appear in it twice (read but
/// not listed).
const MODELED_REQUEST_KEYS: &[&str] = &[
    "messages",
    "model",
    "temperature",
    "max_tokens",
    "max_completion_tokens",
    "top_p",
    "stop",
    "tools",
    "tool_choice",
    "store",
    "user",
    "metadata",
    "service_tier",
    "parallel_tool_calls",
    "top_logprobs",
    "stream",
];
123
124impl LlmResponseCodec for OpenAIChatCodec {
129 fn decode_response(&self, response: &Json) -> Result<AnnotatedLlmResponse> {
130 let raw: RawChatCompletion = serde_json::from_value(response.clone())
131 .map_err(|e| FlowError::Internal(format!("OpenAI Chat response decode: {e}")))?;
132
133 let choice = raw.choices.as_ref().and_then(|c| c.first());
135
136 let message = choice
138 .and_then(|c| c.message.as_ref())
139 .and_then(|m| m.content.as_ref())
140 .map(|s| super::request::MessageContent::Text(s.clone()));
141
142 let tool_calls = choice
146 .and_then(|c| c.message.as_ref())
147 .and_then(|m| m.tool_calls.as_ref())
148 .map(|tcs| {
149 tcs.iter()
150 .filter_map(|tc| {
151 let func = tc.function.as_ref()?;
152 let name = func.name.as_ref()?;
153 Some(ResponseToolCall {
154 id: tc.id.clone().unwrap_or_default(),
155 name: name.clone(),
156 arguments: func
157 .arguments
158 .as_deref()
159 .map(parse_arguments)
160 .unwrap_or(Json::Object(Default::default())),
161 })
162 })
163 .collect::<Vec<_>>()
164 });
165
166 let finish_reason = choice
168 .and_then(|c| c.finish_reason.as_deref())
169 .map(map_chat_finish_reason);
170
171 let usage = raw.usage.map(|u| Usage {
173 prompt_tokens: u.prompt_tokens,
174 completion_tokens: u.completion_tokens,
175 total_tokens: u.total_tokens,
176 cache_read_tokens: u.prompt_tokens_details.and_then(|d| d.cached_tokens),
177 cache_write_tokens: None,
178 });
179
180 let logprobs = choice.and_then(|c| c.logprobs.clone());
182 let api_specific = Some(ApiSpecificResponse::OpenAIChat {
183 logprobs,
184 system_fingerprint: raw.system_fingerprint,
185 service_tier: raw.service_tier,
186 });
187
188 Ok(AnnotatedLlmResponse {
189 id: raw.id,
190 model: raw.model,
191 message,
192 tool_calls,
193 finish_reason,
194 usage,
195 api_specific,
196 extra: raw.extra,
197 })
198 }
199}
200
201impl LlmCodec for OpenAIChatCodec {
206 fn decode(&self, request: &LlmRequest) -> Result<AnnotatedLlmRequest> {
207 let obj = request
208 .content
209 .as_object()
210 .ok_or_else(|| FlowError::Internal("request content is not an object".into()))?;
211
212 let messages: Vec<Message> = obj
214 .get("messages")
215 .map(|v| serde_json::from_value(v.clone()).unwrap_or_default())
216 .unwrap_or_default();
217
218 let model = obj.get("model").and_then(|v| v.as_str()).map(String::from);
220
221 let temperature = obj.get("temperature").and_then(|v| v.as_f64());
223 let top_p = obj.get("top_p").and_then(|v| v.as_f64());
224 let stop = obj
225 .get("stop")
226 .and_then(|v| serde_json::from_value::<Vec<String>>(v.clone()).ok());
227
228 let max_tokens = obj
230 .get("max_completion_tokens")
231 .and_then(|v| v.as_u64())
232 .or_else(|| obj.get("max_tokens").and_then(|v| v.as_u64()));
233
234 let params =
235 if temperature.is_some() || max_tokens.is_some() || top_p.is_some() || stop.is_some() {
236 Some(GenerationParams {
237 temperature,
238 max_tokens,
239 top_p,
240 stop,
241 })
242 } else {
243 None
244 };
245
246 let tools: Option<Vec<ToolDefinition>> = obj
248 .get("tools")
249 .map(|v| serde_json::from_value(v.clone()))
250 .transpose()
251 .map_err(|e| FlowError::Internal(format!("OpenAI Chat tools decode: {e}")))?;
252
253 let tool_choice: Option<ToolChoice> = obj
255 .get("tool_choice")
256 .map(|v| serde_json::from_value(v.clone()))
257 .transpose()
258 .map_err(|e| FlowError::Internal(format!("OpenAI Chat tool_choice decode: {e}")))?;
259
260 let extra: serde_json::Map<String, Json> = obj
262 .iter()
263 .filter(|(k, _)| !MODELED_REQUEST_KEYS.contains(&k.as_str()))
264 .map(|(k, v)| (k.clone(), v.clone()))
265 .collect();
266
267 Ok(AnnotatedLlmRequest {
268 messages,
269 model,
270 params,
271 tools,
272 tool_choice,
273 store: obj.get("store").and_then(|v| v.as_bool()),
274 previous_response_id: None,
275 truncation: None,
276 reasoning: None,
277 include: None,
278 user: obj.get("user").and_then(|v| v.as_str()).map(String::from),
279 metadata: obj.get("metadata").cloned(),
280 service_tier: obj
281 .get("service_tier")
282 .and_then(|v| v.as_str())
283 .map(String::from),
284 parallel_tool_calls: obj.get("parallel_tool_calls").and_then(|v| v.as_bool()),
285 max_output_tokens: None,
286 max_tool_calls: None,
287 top_logprobs: obj.get("top_logprobs").and_then(|v| v.as_u64()),
288 stream: obj.get("stream").and_then(|v| v.as_bool()),
289 extra,
290 })
291 }
292
293 fn encode(&self, annotated: &AnnotatedLlmRequest, original: &LlmRequest) -> Result<LlmRequest> {
294 let mut content = original.content.clone();
295 let obj = content
296 .as_object_mut()
297 .ok_or_else(|| FlowError::Internal("original content is not an object".into()))?;
298
299 insert_serialized(obj, "messages", &annotated.messages, "messages")?;
300
301 if let Some(ref model) = annotated.model {
302 obj.insert("model".into(), Json::String(model.clone()));
303 }
304
305 if let Some(ref params) = annotated.params {
306 overlay_generation_params(obj, params)?;
307 }
308
309 if let Some(ref tools) = annotated.tools {
310 insert_serialized(obj, "tools", tools, "tools")?;
311 }
312
313 if let Some(ref tool_choice) = annotated.tool_choice {
314 insert_serialized(obj, "tool_choice", tool_choice, "tool_choice")?;
315 }
316
317 if let Some(store) = annotated.store {
318 obj.insert("store".into(), Json::Bool(store));
319 }
320 if let Some(ref user) = annotated.user {
321 obj.insert("user".into(), Json::String(user.clone()));
322 }
323 if let Some(ref metadata) = annotated.metadata {
324 obj.insert("metadata".into(), metadata.clone());
325 }
326 if let Some(ref service_tier) = annotated.service_tier {
327 obj.insert("service_tier".into(), Json::String(service_tier.clone()));
328 }
329 if let Some(parallel_tool_calls) = annotated.parallel_tool_calls {
330 obj.insert(
331 "parallel_tool_calls".into(),
332 Json::Bool(parallel_tool_calls),
333 );
334 }
335 if let Some(top_logprobs) = annotated.top_logprobs {
336 obj.insert("top_logprobs".into(), Json::from(top_logprobs));
337 }
338 if let Some(stream) = annotated.stream {
339 obj.insert("stream".into(), Json::Bool(stream));
340 }
341
342 for (k, v) in &annotated.extra {
343 obj.insert(k.clone(), v.clone());
344 }
345
346 let is_streaming = obj.get("stream").and_then(|v| v.as_bool()).unwrap_or(false);
361 if is_streaming && !obj.contains_key("stream_options") {
362 obj.insert(
363 "stream_options".into(),
364 serde_json::json!({"include_usage": true}),
365 );
366 }
367
368 Ok(LlmRequest {
369 headers: original.headers.clone(),
370 content,
371 })
372 }
373}
374
375fn json_f64(v: f64) -> Json {
377 serde_json::Number::from_f64(v)
378 .map(Json::Number)
379 .unwrap_or(Json::Null)
380}
381
382fn insert_serialized<T: serde::Serialize>(
383 obj: &mut serde_json::Map<String, Json>,
384 key: &str,
385 value: &T,
386 context: &str,
387) -> Result<()> {
388 let json = serde_json::to_value(value)
389 .map_err(|e| FlowError::Internal(format!("OpenAI Chat {context} encode: {e}")))?;
390 obj.insert(key.into(), json);
391 Ok(())
392}
393
394fn overlay_generation_params(
395 obj: &mut serde_json::Map<String, Json>,
396 params: &GenerationParams,
397) -> Result<()> {
398 if let Some(temp) = params.temperature {
399 obj.insert("temperature".into(), json_f64(temp));
400 }
401 if let Some(top_p) = params.top_p {
402 obj.insert("top_p".into(), json_f64(top_p));
403 }
404 if let Some(ref stop) = params.stop {
405 insert_serialized(obj, "stop", stop, "stop")?;
406 }
407 if let Some(max_tokens) = params.max_tokens {
408 let key = if obj.contains_key("max_completion_tokens") {
409 "max_completion_tokens"
410 } else {
411 "max_tokens"
412 };
413 obj.insert(key.into(), Json::from(max_tokens));
414 }
415 Ok(())
416}
417
/// Streaming counterpart of `OpenAIChatCodec`.
///
/// It collects Chat Completions stream chunks and assembles them into a single response body.
/// The collector and finalizer closures it hands out share `state`; the finalizer takes the
/// accumulated state and leaves an empty one behind.
///
/// NOTE(review): all closures from one instance feed the same accumulator, so concurrent
/// streams through a single instance would interleave. Confirm that callers create one codec
/// per stream.
pub struct OpenAIChatStreamingCodec {
    state: std::sync::Arc<std::sync::Mutex<OpenAIChatStreamingState>>,
}
450
451impl OpenAIChatStreamingCodec {
452 pub fn new() -> Self {
454 Self {
455 state: std::sync::Arc::new(std::sync::Mutex::new(OpenAIChatStreamingState::default())),
456 }
457 }
458}
459
460impl Default for OpenAIChatStreamingCodec {
461 fn default() -> Self {
462 Self::new()
463 }
464}
465
466impl super::streaming::StreamingCodec for OpenAIChatStreamingCodec {
467 fn collector(&self) -> crate::api::runtime::LlmCollectorFn {
468 let state = std::sync::Arc::clone(&self.state);
469 Box::new(move |event: Json| -> Result<()> {
470 let mut guard = state
471 .lock()
472 .unwrap_or_else(|poisoned| poisoned.into_inner());
473 guard.observe(&event);
474 Ok(())
475 })
476 }
477
478 fn finalizer(&self) -> crate::api::runtime::LlmFinalizerFn {
479 let state = std::sync::Arc::clone(&self.state);
480 Box::new(move || -> Json {
481 let mut guard = state
482 .lock()
483 .unwrap_or_else(|poisoned| poisoned.into_inner());
484 std::mem::take(&mut *guard).finalize()
485 })
486 }
487}
488
489#[derive(Debug, Default)]
490struct OpenAIChatStreamingState {
491 id: Option<String>,
492 object: Option<String>,
493 created: Option<u64>,
494 model: Option<String>,
495 choices: std::collections::BTreeMap<u64, ChoiceState>,
498 usage: Option<Json>,
500}
501
502#[derive(Debug, Default)]
503struct ChoiceState {
504 role: Option<String>,
505 content: String,
506 has_content: bool,
507 tool_calls: std::collections::BTreeMap<u64, ToolCallState>,
510 finish_reason: Option<String>,
511}
512
513#[derive(Debug, Default)]
514struct ToolCallState {
515 id: Option<String>,
516 type_: Option<String>,
517 name: Option<String>,
518 arguments: String,
519}
520
521impl OpenAIChatStreamingState {
522 fn observe(&mut self, chunk: &Json) {
523 if self.id.is_none()
526 && let Some(id) = chunk.get("id").and_then(Json::as_str)
527 {
528 self.id = Some(id.to_string());
529 }
530 if self.object.is_none()
531 && let Some(obj) = chunk.get("object").and_then(Json::as_str)
532 {
533 self.object = Some(obj.to_string());
534 }
535 if self.created.is_none()
536 && let Some(c) = chunk.get("created").and_then(Json::as_u64)
537 {
538 self.created = Some(c);
539 }
540 if self.model.is_none()
541 && let Some(m) = chunk.get("model").and_then(Json::as_str)
542 {
543 self.model = Some(m.to_string());
544 }
545 if let Some(usage) = chunk.get("usage") {
546 if !usage.is_null() {
549 self.usage = Some(usage.clone());
550 }
551 }
552 let Some(choices) = chunk.get("choices").and_then(Json::as_array) else {
553 return;
554 };
555 for choice in choices {
556 self.observe_choice(choice);
557 }
558 }
559
560 fn observe_choice(&mut self, choice: &Json) {
561 let index = choice.get("index").and_then(Json::as_u64).unwrap_or(0);
562 let entry = self.choices.entry(index).or_default();
563 entry.observe_finish_reason(choice);
564 entry.observe_delta(choice.get("delta"));
565 }
566
567 fn finalize(self) -> Json {
568 let mut output = serde_json::Map::new();
569 if let Some(id) = self.id {
570 output.insert("id".to_string(), Json::String(id));
571 }
572 if let Some(object) = self.object {
577 let normalized = object
578 .strip_suffix(".chunk")
579 .map(str::to_string)
580 .unwrap_or(object);
581 output.insert("object".to_string(), Json::String(normalized));
582 }
583 if let Some(created) = self.created {
584 output.insert("created".to_string(), Json::Number(created.into()));
585 }
586 if let Some(model) = self.model {
587 output.insert("model".to_string(), Json::String(model));
588 }
589 let choices: Vec<Json> = self
590 .choices
591 .into_iter()
592 .map(|(index, choice)| choice.finalize(index))
593 .collect();
594 output.insert("choices".to_string(), Json::Array(choices));
595 if let Some(usage) = self.usage {
596 output.insert("usage".to_string(), usage);
597 }
598 Json::Object(output)
599 }
600}
601
602impl ChoiceState {
603 fn observe_finish_reason(&mut self, choice: &Json) {
604 if let Some(reason) = choice.get("finish_reason").and_then(Json::as_str) {
605 self.finish_reason = Some(reason.to_string());
606 }
607 }
608
609 fn observe_delta(&mut self, delta: Option<&Json>) {
610 let Some(delta) = delta else {
611 return;
612 };
613 if let Some(role) = delta.get("role").and_then(Json::as_str) {
614 self.role = Some(role.to_string());
615 }
616 if let Some(content) = delta.get("content").and_then(Json::as_str) {
617 self.content.push_str(content);
618 self.has_content = true;
619 }
620 self.observe_tool_calls(delta);
621 }
622
623 fn observe_tool_calls(&mut self, delta: &Json) {
624 if let Some(tool_calls) = delta.get("tool_calls").and_then(Json::as_array) {
625 for tool_call in tool_calls {
626 self.observe_tool_call(tool_call);
627 }
628 }
629 }
630
631 fn observe_tool_call(&mut self, tool_call: &Json) {
632 let index = tool_call.get("index").and_then(Json::as_u64).unwrap_or(0);
633 let state = self.tool_calls.entry(index).or_default();
634 if let Some(id) = tool_call.get("id").and_then(Json::as_str) {
635 state.id = Some(id.to_string());
636 }
637 if let Some(type_) = tool_call.get("type").and_then(Json::as_str) {
638 state.type_ = Some(type_.to_string());
639 }
640 if let Some(function) = tool_call.get("function") {
641 state.observe_function(function);
642 }
643 }
644
645 fn finalize(self, index: u64) -> Json {
646 let mut message = serde_json::Map::new();
647 message.insert(
648 "role".to_string(),
649 Json::String(self.role.unwrap_or_else(|| "assistant".to_string())),
650 );
651 if self.has_content {
655 message.insert("content".to_string(), Json::String(self.content));
656 } else {
657 message.insert("content".to_string(), Json::Null);
658 }
659 if !self.tool_calls.is_empty() {
660 let tool_calls: Vec<Json> = self
661 .tool_calls
662 .into_values()
663 .map(ToolCallState::finalize)
664 .collect();
665 message.insert("tool_calls".to_string(), Json::Array(tool_calls));
666 }
667 let mut choice = serde_json::Map::new();
668 choice.insert("index".to_string(), Json::Number(index.into()));
669 choice.insert("message".to_string(), Json::Object(message));
670 if let Some(reason) = self.finish_reason {
671 choice.insert("finish_reason".to_string(), Json::String(reason));
672 } else {
673 choice.insert("finish_reason".to_string(), Json::Null);
674 }
675 Json::Object(choice)
676 }
677}
678
679impl ToolCallState {
680 fn observe_function(&mut self, function: &Json) {
681 if let Some(name) = function.get("name").and_then(Json::as_str) {
682 self.name = Some(name.to_string());
683 }
684 if let Some(args) = function.get("arguments").and_then(Json::as_str) {
685 self.arguments.push_str(args);
686 }
687 }
688
689 fn finalize(self) -> Json {
690 let mut function = serde_json::Map::new();
691 function.insert(
692 "name".to_string(),
693 Json::String(self.name.unwrap_or_default()),
694 );
695 function.insert("arguments".to_string(), Json::String(self.arguments));
696 let mut call = serde_json::Map::new();
697 if let Some(id) = self.id {
698 call.insert("id".to_string(), Json::String(id));
699 }
700 call.insert(
701 "type".to_string(),
702 Json::String(self.type_.unwrap_or_else(|| "function".to_string())),
703 );
704 call.insert("function".to_string(), Json::Object(function));
705 Json::Object(call)
706 }
707}
708
#[cfg(test)]
// The unit tests live in a separate file outside `src`; `#[path]` mounts them here as a
// child module so they can reach this module's private items.
#[path = "../../tests/unit/codec/openai_chat_tests.rs"]
mod tests;