1use std::collections::HashMap;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use defect_core::error::BoxError;
19use defect_core::llm::{
20 CompletionRequest, ImageData, Message, MessageContent, ProviderChunk, ProviderError,
21 ProviderErrorKind, ReasoningEffort, Role, StopReason, ThinkingConfig, ThinkingEcho, ToolChoice,
22 ToolResultBody, ToolResultContent, Usage,
23};
24use defect_core::tool::ToolSchema;
25use futures::Stream;
26use sse_stream::Sse;
27use toac::body::codec::sse::SseEventStream;
28use tokio_util::sync::CancellationToken;
29use tracing::warn;
30
31use crate::wire::openai::components as wire;
32
/// Prefix placed in front of the hexadecimal digest of every prompt cache key.
const PROMPT_CACHE_KEY_PREFIX: &str = "defect:chat:v1:";
/// Initial hasher state; this value equals the 64-bit FNV offset basis.
const PROMPT_CACHE_KEY_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
/// Multiplier applied after each byte is mixed in; this value equals the 64-bit FNV prime.
const PROMPT_CACHE_KEY_PRIME: u64 = 0x0100_0000_01b3;
38
/// Signature of the function that turns a wire-level usage record into a [`Usage`].
///
/// The first argument is the raw JSON `usage` value of the chunk (when one could be looked up);
/// the second is the typed [`wire::CompletionUsage`] decoded from the same chunk.
type UsageParser = fn(Option<&serde_json::Value>, &wire::CompletionUsage) -> Usage;
40
41pub fn encode_request(req: &CompletionRequest) -> wire::CreateChatCompletionRequest {
64 encode_request_with_echo(req, ThinkingEcho::Forbidden)
65}
66
67pub fn encode_request_with_echo(
74 req: &CompletionRequest,
75 echo_mode: ThinkingEcho,
76) -> wire::CreateChatCompletionRequest {
77 encode_request_full(req, echo_mode, None)
78}
79
80pub fn encode_request_full(
85 req: &CompletionRequest,
86 echo_mode: ThinkingEcho,
87 effort_override: Option<ReasoningEffort>,
88) -> wire::CreateChatCompletionRequest {
89 encode_request_with_dialect(req, echo_mode, effort_override, ChatDialect::OpenAi)
90}
91
/// Flavor of the chat-completions wire format to produce.
///
/// The dialect selects which token-limit field is populated, whether a prompt cache key is sent,
/// and how assistant reasoning text is echoed back (see [`encode_request_with_dialect`]).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ChatDialect {
    /// OpenAI-style requests; this is the default dialect.
    #[default]
    OpenAi,
    /// DeepSeek-style requests.
    DeepSeek,
}
102
103pub fn encode_request_with_dialect(
106 req: &CompletionRequest,
107 echo_mode: ThinkingEcho,
108 effort_override: Option<ReasoningEffort>,
109 dialect: ChatDialect,
110) -> wire::CreateChatCompletionRequest {
111 let mut messages = Vec::with_capacity(req.messages.len() + 1);
112 if let Some(sys) = req.system.as_ref() {
113 messages.push(encode_system_message(sys));
114 }
115 for m in &req.messages {
116 encode_message_into(m, echo_mode, dialect, &mut messages);
117 }
118
119 let max_tokens = req.sampling.max_tokens.map(i64::from);
120 #[allow(deprecated)]
121 wire::CreateChatCompletionRequest {
122 messages,
124 model: wire::ModelIdsShared::ModelIdsSharedVariant0(req.model.clone()),
125 stream: Some(true),
126 stream_options: Some(wire::ChatCompletionStreamOptions::ChatCompletionStreamOptionsVariant0(
127 wire::ChatCompletionStreamOptionsVariant0 {
128 include_usage: Some(true),
129 include_obfuscation: None,
130 },
131 )),
132 max_completion_tokens: match dialect {
133 ChatDialect::OpenAi => max_tokens,
134 ChatDialect::DeepSeek => None,
135 },
136 temperature: req.sampling.temperature.map(|t| {
137 wire::CreateChatCompletionRequestTemperature::CreateChatCompletionRequestTemperatureVariant0(
138 f64::from(t),
139 )
140 }),
141 top_p: req.sampling.top_p.map(|t| {
142 wire::CreateChatCompletionRequestTopP::CreateChatCompletionRequestTopPVariant0(
143 f64::from(t),
144 )
145 }),
146 stop: if req.sampling.stop_sequences.is_empty() {
147 None
148 } else {
149 Some(wire::StopConfiguration::StopConfigurationVariant1(
150 req.sampling.stop_sequences.clone(),
151 ))
152 },
153 reasoning_effort: req
158 .sampling
159 .reasoning_effort
160 .or(effort_override)
161 .map(encode_reasoning_effort)
162 .or_else(|| encode_thinking(req.sampling.thinking)),
163 tools: if req.tools.is_empty() {
164 None
165 } else {
166 Some(req.tools.iter().map(encode_tool).collect())
167 },
168 tool_choice: encode_tool_choice(&req.tool_choice),
169 metadata: None,
171 top_logprobs: None,
172 user: None,
173 safety_identifier: None,
174 prompt_cache_key: match dialect {
175 ChatDialect::OpenAi => Some(build_prompt_cache_key(req, echo_mode)),
176 ChatDialect::DeepSeek => None,
177 },
178 service_tier: None,
179 prompt_cache_retention: None,
180 modalities: None,
181 verbosity: None,
182 frequency_penalty: None,
183 presence_penalty: None,
184 web_search_options: None,
185 response_format: None,
186 audio: None,
187 store: None,
188 logit_bias: None,
189 logprobs: None,
190 max_tokens: match dialect {
191 ChatDialect::OpenAi => None,
192 ChatDialect::DeepSeek => max_tokens,
193 },
194 n: None,
195 prediction: None,
196 seed: None,
197 parallel_tool_calls: None,
198 function_call: None,
199 functions: None,
200 }
201}
202
203fn build_prompt_cache_key(req: &CompletionRequest, echo_mode: ThinkingEcho) -> String {
204 let mut hasher = PromptCacheKeyHasher::new();
205 hasher.write_str(&req.model);
206 if let Some(system) = req.system.as_deref() {
207 hasher.write_str(system);
208 }
209 hasher.write_str(prompt_cache_echo_mode(echo_mode));
210 hasher.write_str(prompt_cache_tool_choice(&req.tool_choice));
211 hasher.write_json(&req.tools);
212 format!("{PROMPT_CACHE_KEY_PREFIX}{:016x}", hasher.finish())
213}
214
215fn prompt_cache_echo_mode(mode: ThinkingEcho) -> &'static str {
216 match mode {
217 ThinkingEcho::Forbidden => "forbidden",
218 ThinkingEcho::Required => "required",
219 ThinkingEcho::Optional => "optional",
220 }
221}
222
223fn prompt_cache_tool_choice(choice: &ToolChoice) -> &str {
224 match choice {
225 ToolChoice::Auto => "auto",
226 ToolChoice::Required => "required",
227 ToolChoice::Named { name } => name.as_str(),
228 ToolChoice::None => "none",
229 }
230}
231
232struct PromptCacheKeyHasher {
233 state: u64,
234}
235
236impl PromptCacheKeyHasher {
237 fn new() -> Self {
238 Self {
239 state: PROMPT_CACHE_KEY_OFFSET_BASIS,
240 }
241 }
242
243 fn write_json<T>(&mut self, value: &T)
244 where
245 T: serde::Serialize,
246 {
247 let Ok(encoded) = serde_json::to_vec(value) else {
248 return;
249 };
250 self.write_bytes(&encoded);
251 }
252
253 fn write_str(&mut self, value: &str) {
254 self.write_bytes(value.as_bytes());
255 }
256
257 fn write_bytes(&mut self, bytes: &[u8]) {
258 for byte in bytes {
259 self.state ^= u64::from(*byte);
260 self.state = self.state.wrapping_mul(PROMPT_CACHE_KEY_PRIME);
261 }
262 self.state ^= u64::from(b'\n');
263 self.state = self.state.wrapping_mul(PROMPT_CACHE_KEY_PRIME);
264 }
265
266 fn finish(self) -> u64 {
267 self.state
268 }
269}
270
271fn encode_system_message(text: &str) -> wire::ChatCompletionRequestMessage {
272 wire::ChatCompletionRequestMessage::ChatCompletionRequestSystemMessage(
273 wire::ChatCompletionRequestSystemMessage {
274 content: wire::ChatCompletionRequestSystemMessageContent::ChatCompletionRequestSystemMessageContentVariant0(
275 text.to_owned(),
276 ),
277 role: wire::ChatCompletionRequestSystemMessageRole::System,
278 name: None,
279 },
280 )
281}
282
283fn encode_message_into(
289 m: &Message,
290 echo_mode: ThinkingEcho,
291 dialect: ChatDialect,
292 out: &mut Vec<wire::ChatCompletionRequestMessage>,
293) {
294 match m.role {
295 Role::User => encode_user_message_into(m, out),
296 Role::Assistant => encode_assistant_message_into(m, echo_mode, dialect, out),
297 }
298}
299
300fn encode_user_message_into(m: &Message, out: &mut Vec<wire::ChatCompletionRequestMessage>) {
301 let mut user_parts: Vec<wire::ChatCompletionRequestUserMessageContentPart> = Vec::new();
302 let mut tool_results: Vec<(String, String)> = Vec::new(); for c in m.content.iter() {
305 match c {
306 MessageContent::Text { text } => {
307 user_parts.push(
308 wire::ChatCompletionRequestUserMessageContentPart::ChatCompletionRequestMessageContentPartText(
309 wire::ChatCompletionRequestMessageContentPartText {
310 r#type: wire::ChatCompletionRequestMessageContentPartTextType::Text,
311 text: text.clone(),
312 },
313 ),
314 );
315 }
316 MessageContent::Image { mime, data } => {
317 user_parts.push(image_part(mime, data));
318 }
319 MessageContent::ToolResult {
320 tool_use_id,
321 output,
322 is_error: _,
323 } => {
324 let text = match output {
335 ToolResultBody::Text { text } => text.clone(),
336 ToolResultBody::Json { value } => value.to_string(),
337 ToolResultBody::Content { blocks } => {
338 let mut text = String::new();
339 let mut image_count = 0usize;
340 for block in blocks {
341 match block {
342 ToolResultContent::Text { text: t } => {
343 if !text.is_empty() {
344 text.push('\n');
345 }
346 text.push_str(t);
347 }
348 ToolResultContent::Image { mime, data } => {
349 image_count += 1;
350 user_parts.push(image_part(mime, data));
351 }
352 }
353 }
354 if image_count > 0 {
355 if !text.is_empty() {
356 text.push('\n');
357 }
358 text.push_str(&format!(
359 "[{image_count} image(s) from this tool result follow in the next user message]"
360 ));
361 }
362 text
363 }
364 };
365 tool_results.push((tool_use_id.clone(), text));
366 }
367 _ => {
369 user_parts.push(
370 wire::ChatCompletionRequestUserMessageContentPart::ChatCompletionRequestMessageContentPartText(
371 wire::ChatCompletionRequestMessageContentPartText {
372 r#type: wire::ChatCompletionRequestMessageContentPartTextType::Text,
373 text: String::new(),
374 },
375 ),
376 );
377 }
378 }
379 }
380
381 for (tool_use_id, text) in tool_results {
385 out.push(wire::ChatCompletionRequestMessage::ChatCompletionRequestToolMessage(
386 wire::ChatCompletionRequestToolMessage {
387 role: wire::ChatCompletionRequestToolMessageRole::Tool,
388 content: wire::ChatCompletionRequestToolMessageContent::ChatCompletionRequestToolMessageContentVariant0(
389 text,
390 ),
391 tool_call_id: tool_use_id,
392 },
393 ));
394 }
395 if !user_parts.is_empty() {
396 out.push(wire::ChatCompletionRequestMessage::ChatCompletionRequestUserMessage(
397 wire::ChatCompletionRequestUserMessage {
398 content: wire::ChatCompletionRequestUserMessageContent::ChatCompletionRequestUserMessageContentVariant1(
399 user_parts,
400 ),
401 role: wire::ChatCompletionRequestUserMessageRole::User,
402 name: None,
403 },
404 ));
405 }
406}
407
408fn encode_assistant_message_into(
409 m: &Message,
410 echo_mode: ThinkingEcho,
411 dialect: ChatDialect,
412 out: &mut Vec<wire::ChatCompletionRequestMessage>,
413) {
414 const EMPTY_ASSISTANT_CONTENT: &str = "";
415
416 let mut text_parts: Vec<String> = Vec::new();
417 let mut tool_calls: Vec<wire::ChatCompletionMessageToolCallsItem> = Vec::new();
418 let mut reasoning_text = String::new();
419
420 for c in m.content.iter() {
421 match c {
422 MessageContent::Text { text } => text_parts.push(text.clone()),
423 MessageContent::Thinking { text, .. } => {
424 reasoning_text.push_str(text);
427 }
428 MessageContent::ToolUse { id, name, args } => {
429 tool_calls.push(
430 wire::ChatCompletionMessageToolCallsItem::ChatCompletionMessageToolCall(
431 wire::ChatCompletionMessageToolCall {
432 id: id.clone(),
433 r#type: wire::ChatCompletionMessageToolCallType::Function,
434 function: wire::ChatCompletionMessageToolCallFunction {
435 name: name.clone(),
436 arguments: serde_json::to_string(args).unwrap_or_default(),
437 },
438 },
439 ),
440 );
441 }
442 _ => {}
446 }
447 }
448
449 let reasoning_content = match dialect {
450 ChatDialect::DeepSeek => Some(reasoning_text),
451 ChatDialect::OpenAi => match (echo_mode, reasoning_text.is_empty()) {
452 (ThinkingEcho::Required, false) => Some(reasoning_text),
453 (ThinkingEcho::Optional, false) => Some(reasoning_text),
457 _ => None,
458 },
459 };
460 let content = if text_parts.is_empty() {
461 if tool_calls.is_empty() && reasoning_content.is_some() {
462 Some(wire::ChatCompletionRequestAssistantMessageContent::ChatCompletionRequestAssistantMessageContentVariant0(
466 wire::ChatCompletionRequestAssistantMessageContentVariant0::ChatCompletionRequestAssistantMessageContentVariant0Variant0(
467 EMPTY_ASSISTANT_CONTENT.to_owned(),
468 ),
469 ))
470 } else {
471 None
472 }
473 } else {
474 Some(wire::ChatCompletionRequestAssistantMessageContent::ChatCompletionRequestAssistantMessageContentVariant0(
475 wire::ChatCompletionRequestAssistantMessageContentVariant0::ChatCompletionRequestAssistantMessageContentVariant0Variant0(
476 text_parts.join(""),
477 ),
478 ))
479 };
480
481 #[allow(deprecated)]
482 out.push(
483 wire::ChatCompletionRequestMessage::ChatCompletionRequestAssistantMessage(
484 wire::ChatCompletionRequestAssistantMessage {
485 content,
486 refusal: None,
487 role: wire::ChatCompletionRequestAssistantMessageRole::Assistant,
488 name: None,
489 audio: None,
490 tool_calls: if tool_calls.is_empty() {
491 None
492 } else {
493 Some(tool_calls)
494 },
495 function_call: None,
496 reasoning_content,
497 },
498 ),
499 );
500}
501
502fn image_part(mime: &str, data: &ImageData) -> wire::ChatCompletionRequestUserMessageContentPart {
505 wire::ChatCompletionRequestUserMessageContentPart::ChatCompletionRequestMessageContentPartImage(
506 wire::ChatCompletionRequestMessageContentPartImage {
507 r#type: wire::ChatCompletionRequestMessageContentPartImageType::ImageUrl,
508 image_url: wire::ChatCompletionRequestMessageContentPartImageImageUrl {
509 url: image_url_string(mime, data),
510 detail: None,
511 },
512 },
513 )
514}
515
516fn image_url_string(mime: &str, data: &ImageData) -> String {
517 match data {
518 ImageData::Url { url } => url.clone(),
519 ImageData::Base64 { encoded } => format!("data:{mime};base64,{encoded}"),
520 }
521}
522
523fn encode_thinking(t: ThinkingConfig) -> Option<wire::ReasoningEffort> {
524 match t {
525 ThinkingConfig::Disabled => None,
526 ThinkingConfig::Enabled { .. } => Some(wire::ReasoningEffort::ReasoningEffortVariant0(
530 wire::ReasoningEffortVariant0::Medium,
531 )),
532 }
533}
534
535fn encode_reasoning_effort(effort: ReasoningEffort) -> wire::ReasoningEffort {
536 use ReasoningEffort as E;
537 use wire::ReasoningEffortVariant0 as V;
538 let v = match effort {
539 E::None => V::None,
540 E::Minimal => V::Minimal,
541 E::Low => V::Low,
542 E::Medium => V::Medium,
543 E::High => V::High,
544 E::Xhigh => V::Xhigh,
545 };
546 wire::ReasoningEffort::ReasoningEffortVariant0(v)
547}
548
549fn encode_tool_choice(c: &ToolChoice) -> Option<wire::ChatCompletionToolChoiceOption> {
550 match c {
551 ToolChoice::Auto => Some(
552 wire::ChatCompletionToolChoiceOption::ChatCompletionToolChoiceOptionVariant0(
553 wire::ChatCompletionToolChoiceOptionVariant0::Auto,
554 ),
555 ),
556 ToolChoice::Required => Some(
557 wire::ChatCompletionToolChoiceOption::ChatCompletionToolChoiceOptionVariant0(
558 wire::ChatCompletionToolChoiceOptionVariant0::Required,
559 ),
560 ),
561 ToolChoice::None => Some(
562 wire::ChatCompletionToolChoiceOption::ChatCompletionToolChoiceOptionVariant0(
563 wire::ChatCompletionToolChoiceOptionVariant0::None,
564 ),
565 ),
566 ToolChoice::Named { name } => Some(
567 wire::ChatCompletionToolChoiceOption::ChatCompletionNamedToolChoice(
568 wire::ChatCompletionNamedToolChoice {
569 r#type: wire::ChatCompletionNamedToolChoiceType::Function,
570 function: wire::ChatCompletionNamedToolChoiceFunction { name: name.clone() },
571 },
572 ),
573 ),
574 }
575}
576
577fn encode_tool(t: &ToolSchema) -> wire::CreateChatCompletionRequestTools {
578 wire::CreateChatCompletionRequestTools::ChatCompletionTool(wire::ChatCompletionTool {
579 r#type: wire::ChatCompletionToolType::Function,
580 function: wire::FunctionObject {
581 name: t.name.clone(),
582 description: if t.description.is_empty() {
583 None
584 } else {
585 Some(t.description.clone())
586 },
587 parameters: Some(json_value_to_parameters(&t.input_schema)),
588 strict: None,
589 },
590 })
591}
592
593fn json_value_to_parameters(v: &serde_json::Value) -> wire::FunctionParameters {
594 v.as_object()
595 .map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
596 .unwrap_or_default()
597}
598
/// Mutable bookkeeping carried across the chunks of one streamed response.
#[derive(Default, Debug)]
struct DecoderState {
    /// Set once the first chunk has been handled and the message-start chunk emitted.
    started: bool,
    /// Set once a `finish_reason` has been handled and the stop chunk emitted.
    stopped: bool,
    /// Set when the `[DONE]` sentinel has been received.
    done: bool,
    /// Ends the stream when set. NOTE(review): nothing in this part of the file sets it —
    /// confirm whether it is written elsewhere.
    fatal: bool,
    /// Tool calls seen so far, keyed by the index the wire format assigns them.
    tool_calls: HashMap<i64, ToolCallState>,
    /// Tool-call indices in the order they were first seen.
    tool_call_order: Vec<i64>,
}

/// State tracked for a single streamed tool call.
#[derive(Clone, Debug)]
struct ToolCallState {
    /// Identifier taken from the first frame of the tool call.
    id: String,
    /// Whether the end-of-tool-use chunk has already been emitted for this call.
    closed: bool,
}
628
629pub fn decode_stream(
635 sse: SseEventStream,
636 cancel: CancellationToken,
637) -> impl Stream<Item = Result<ProviderChunk, ProviderError>> + Send {
638 decode_stream_with_usage_parser(sse, cancel, usage_from_wire)
639}
640
641pub fn decode_stream_generic<S, E>(
645 sse: S,
646 cancel: CancellationToken,
647) -> impl Stream<Item = Result<ProviderChunk, ProviderError>> + Send
648where
649 S: Stream<Item = Result<Sse, E>> + Send + 'static,
650 E: std::error::Error + Send + Sync + 'static,
651{
652 decode_stream_generic_with_usage_parser(sse, cancel, usage_from_wire)
653}
654
655pub(crate) fn decode_stream_with_usage_parser(
658 sse: SseEventStream,
659 cancel: CancellationToken,
660 usage_parser: UsageParser,
661) -> impl Stream<Item = Result<ProviderChunk, ProviderError>> + Send {
662 decode_stream_generic_with_usage_parser(sse, cancel, usage_parser)
663}
664
665fn decode_stream_generic_with_usage_parser<S, E>(
666 sse: S,
667 cancel: CancellationToken,
668 usage_parser: UsageParser,
669) -> impl Stream<Item = Result<ProviderChunk, ProviderError>> + Send
670where
671 S: Stream<Item = Result<Sse, E>> + Send + 'static,
672 E: std::error::Error + Send + Sync + 'static,
673{
674 OpenAiSseDecoder {
675 inner: sse,
676 cancel,
677 state: DecoderState::default(),
678 pending: Vec::new(),
679 finished: false,
680 usage_parser,
681 _err: std::marker::PhantomData::<E>,
682 }
683}
684
/// Stream adapter that turns SSE events into provider chunks.
struct OpenAiSseDecoder<S, E> {
    /// Upstream SSE event stream; it is polled through a pinned reference.
    inner: S,
    /// Token checked on every poll; once cancelled the stream ends.
    cancel: CancellationToken,
    /// Decoding state carried across chunks.
    state: DecoderState,
    /// Chunks waiting to be yielded. Used as a stack: producers store chunks in reverse order
    /// and the stream pops from the back.
    pending: Vec<Result<ProviderChunk, ProviderError>>,
    /// Once set, the stream yields `None` after `pending` has been drained.
    finished: bool,
    /// Converts wire usage records into [`Usage`].
    usage_parser: UsageParser,
    /// Ties the upstream error type `E` to the decoder without storing a value of it.
    _err: std::marker::PhantomData<E>,
}
697
698impl<S, E> Stream for OpenAiSseDecoder<S, E>
699where
700 S: Stream<Item = Result<Sse, E>>,
701 E: std::error::Error + Send + Sync + 'static,
702{
703 type Item = Result<ProviderChunk, ProviderError>;
704
705 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
706 let this = unsafe { self.get_unchecked_mut() };
709 loop {
710 if let Some(item) = this.pending.pop() {
711 return Poll::Ready(Some(item));
712 }
713 if this.finished {
714 return Poll::Ready(None);
715 }
716 if this.cancel.is_cancelled() {
717 this.finished = true;
718 return Poll::Ready(None);
719 }
720
721 let inner = unsafe { Pin::new_unchecked(&mut this.inner) };
723 match inner.poll_next(cx) {
724 Poll::Pending => return Poll::Pending,
725 Poll::Ready(None) => {
726 this.finished = true;
727 if !this.state.done
729 && !this.state.stopped
730 && this.state.started
731 && !this.state.fatal
732 {
733 return Poll::Ready(Some(Err(ProviderError::new(
734 ProviderErrorKind::ProtocolViolation {
735 hint: "stream ended without finish_reason or [DONE]".into(),
736 },
737 ))));
738 }
739 return Poll::Ready(None);
740 }
741 Poll::Ready(Some(Err(e))) => {
742 this.finished = true;
743 return Poll::Ready(Some(Err(ProviderError::new(
744 ProviderErrorKind::Transport(BoxError::new(e)),
745 ))));
746 }
747 Poll::Ready(Some(Ok(sse))) => {
748 process_sse(&mut this.state, sse, &mut this.pending, this.usage_parser);
749 if this.state.done || this.state.fatal {
750 this.finished = true;
751 }
752 }
753 }
754 }
755 }
756}
757
758fn process_sse(
759 state: &mut DecoderState,
760 sse: Sse,
761 out: &mut Vec<Result<ProviderChunk, ProviderError>>,
762 usage_parser: UsageParser,
763) {
764 let data = match sse.data {
765 Some(d) => d,
766 None => return,
767 };
768 let trimmed = data.trim();
769 if trimmed == "[DONE]" {
772 state.done = true;
773 return;
774 }
775
776 let raw: serde_json::Value = match serde_json::from_str(trimmed) {
780 Ok(v) => v,
781 Err(e) => {
782 out.push(Err(ProviderError::new(ProviderErrorKind::Malformed(
783 BoxError::new(e),
784 ))));
785 return;
786 }
787 };
788
789 let parsed: Result<wire::CreateChatCompletionStreamResponse, _> =
790 serde_json::from_value(raw.clone());
791 let evt = match parsed {
792 Ok(e) => e,
793 Err(e) => {
794 out.push(Err(ProviderError::new(ProviderErrorKind::Malformed(
795 BoxError::new(e),
796 ))));
797 return;
798 }
799 };
800
801 handle_chunk(state, &raw, evt, out, usage_parser);
802}
803
804fn handle_chunk(
805 state: &mut DecoderState,
806 raw: &serde_json::Value,
807 evt: wire::CreateChatCompletionStreamResponse,
808 out: &mut Vec<Result<ProviderChunk, ProviderError>>,
809 usage_parser: UsageParser,
810) {
811 let mut buf: Vec<Result<ProviderChunk, ProviderError>> = Vec::new();
814
815 if !state.started {
819 state.started = true;
820 buf.push(Ok(ProviderChunk::MessageStart {
821 id: evt.id.clone(),
822 model: evt.model.clone(),
823 }));
824 }
825
826 for (choice_idx, choice) in evt.choices.iter().enumerate() {
828 let raw_delta = raw
830 .get("choices")
831 .and_then(|v| v.as_array())
832 .and_then(|a| a.get(choice_idx))
833 .and_then(|c| c.get("delta"));
834
835 let delta = &choice.delta;
836
837 if let Some(rc) = raw_delta
840 .and_then(|d| d.get("reasoning_content"))
841 .and_then(|v| v.as_str())
842 && !rc.is_empty()
843 {
844 buf.push(Ok(ProviderChunk::ThinkingDelta {
845 text: rc.to_owned(),
846 }));
847 }
848
849 if let Some(
851 wire::ChatCompletionStreamResponseDeltaContent::ChatCompletionStreamResponseDeltaContentVariant0(
852 s,
853 ),
854 ) = &delta.content
855 && !s.is_empty()
856 {
857 buf.push(Ok(ProviderChunk::TextDelta { text: s.clone() }));
858 }
859
860 if let Some(calls) = &delta.tool_calls {
863 for tc in calls {
864 handle_tool_call_chunk(state, tc, &mut buf);
865 }
866 }
867
868 if let Some(
872 wire::ChatCompletionStreamResponseDeltaRefusal::ChatCompletionStreamResponseDeltaRefusalVariant0(
873 s,
874 ),
875 ) = &delta.refusal
876 && !s.is_empty()
877 {
878 buf.push(Ok(ProviderChunk::TextDelta { text: s.clone() }));
879 }
880
881 if !state.stopped
896 && let Some(fr) = choice.finish_reason
897 {
898 let order = state.tool_call_order.clone();
899 for idx in order {
900 if let Some(tc) = state.tool_calls.get_mut(&idx)
901 && !tc.closed
902 {
903 tc.closed = true;
904 buf.push(Ok(ProviderChunk::ToolUseEnd { id: tc.id.clone() }));
905 }
906 }
907 state.stopped = true;
908 buf.push(Ok(ProviderChunk::Stop {
909 reason: stop_reason_from_wire(fr),
910 }));
911 }
912 }
913
914 if let Some(usage) = &evt.usage {
916 buf.push(Ok(ProviderChunk::Usage(usage_parser(
917 raw.get("usage"),
918 usage,
919 ))));
920 }
921
922 buf.reverse();
923 out.extend(buf);
924}
925
926fn handle_tool_call_chunk(
927 state: &mut DecoderState,
928 tc: &wire::ChatCompletionMessageToolCallChunk,
929 out: &mut Vec<Result<ProviderChunk, ProviderError>>,
930) {
931 let idx = tc.index;
932 let entry_existed = state.tool_calls.contains_key(&idx);
933
934 if !entry_existed {
938 let Some(id) = tc.id.clone() else {
939 warn!(index = idx, "tool_calls chunk missing id on first frame");
943 return;
944 };
945 let name = tc
946 .function
947 .as_ref()
948 .and_then(|f| f.name.clone())
949 .unwrap_or_default();
950 state.tool_calls.insert(
951 idx,
952 ToolCallState {
953 id: id.clone(),
954 closed: false,
955 },
956 );
957 state.tool_call_order.push(idx);
958 out.push(Ok(ProviderChunk::ToolUseStart { id, name }));
959 }
960
961 if let Some(func) = &tc.function
962 && let Some(args) = &func.arguments
963 && !args.is_empty()
964 && let Some(tool) = state.tool_calls.get(&idx)
965 {
966 out.push(Ok(ProviderChunk::ToolUseArgsDelta {
967 id: tool.id.clone(),
968 fragment: args.clone(),
969 }));
970 }
971}
972
973fn stop_reason_from_wire(
974 r: wire::CreateChatCompletionStreamResponseChoicesFinishReason,
975) -> StopReason {
976 use wire::CreateChatCompletionStreamResponseChoicesFinishReason as W;
977 match r {
978 W::Stop => StopReason::EndTurn,
979 W::Length => StopReason::MaxTokens,
980 W::ToolCalls | W::FunctionCall => StopReason::ToolUse,
981 W::ContentFilter => StopReason::Refusal,
982 }
983}
984
985fn usage_from_wire(_raw_usage: Option<&serde_json::Value>, u: &wire::CompletionUsage) -> Usage {
986 Usage {
987 input_tokens: u64::try_from(u.prompt_tokens).ok(),
988 output_tokens: u64::try_from(u.completion_tokens).ok(),
989 cache_read_input_tokens: u
990 .prompt_tokens_details
991 .as_ref()
992 .and_then(|d| d.cached_tokens)
993 .and_then(|v| u64::try_from(v).ok()),
994 cache_creation_input_tokens: None,
997 }
998}
999
1000#[cfg(test)]
1001mod tests;