rig/providers/cohere/
streaming.rs

1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::http_client::HttpClientExt;
3use crate::http_client::sse::{Event, GenericEventSource};
4use crate::providers::cohere::CompletionModel;
5use crate::providers::cohere::completion::{
6    AssistantContent, CohereCompletionRequest, Message, ToolCall, ToolCallFunction, ToolType, Usage,
7};
8use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent};
9use crate::telemetry::SpanCombinator;
10use crate::{json_utils, streaming};
11use async_stream::stream;
12use futures::StreamExt;
13use serde::{Deserialize, Serialize};
14use tracing::{Level, enabled, info_span};
15use tracing_futures::Instrument;
16
17#[derive(Debug, Deserialize)]
18#[serde(rename_all = "kebab-case", tag = "type")]
19enum StreamingEvent {
20    MessageStart,
21    ContentStart,
22    ContentDelta { delta: Option<Delta> },
23    ContentEnd,
24    ToolPlan,
25    ToolCallStart { delta: Option<Delta> },
26    ToolCallDelta { delta: Option<Delta> },
27    ToolCallEnd,
28    MessageEnd { delta: Option<MessageEndDelta> },
29}
30
31#[derive(Debug, Deserialize)]
32struct MessageContentDelta {
33    text: Option<String>,
34}
35
36#[derive(Debug, Deserialize)]
37struct MessageToolFunctionDelta {
38    name: Option<String>,
39    arguments: Option<String>,
40}
41
42#[derive(Debug, Deserialize)]
43struct MessageToolCallDelta {
44    id: Option<String>,
45    function: Option<MessageToolFunctionDelta>,
46}
47
48#[derive(Debug, Deserialize)]
49struct MessageDelta {
50    content: Option<MessageContentDelta>,
51    tool_calls: Option<MessageToolCallDelta>,
52}
53
54#[derive(Debug, Deserialize)]
55struct Delta {
56    message: Option<MessageDelta>,
57}
58
59#[derive(Debug, Deserialize)]
60struct MessageEndDelta {
61    usage: Option<Usage>,
62}
63
64#[derive(Clone, Serialize, Deserialize)]
65pub struct StreamingCompletionResponse {
66    pub usage: Option<Usage>,
67}
68
69impl GetTokenUsage for StreamingCompletionResponse {
70    fn token_usage(&self) -> Option<crate::completion::Usage> {
71        let tokens = self
72            .usage
73            .clone()
74            .and_then(|response| response.tokens)
75            .map(|tokens| {
76                (
77                    tokens.input_tokens.map(|x| x as u64),
78                    tokens.output_tokens.map(|y| y as u64),
79                )
80            });
81        let Some((Some(input), Some(output))) = tokens else {
82            return None;
83        };
84        let mut usage = crate::completion::Usage::new();
85        usage.input_tokens = input;
86        usage.output_tokens = output;
87        usage.total_tokens = input + output;
88
89        Some(usage)
90    }
91}
92
93impl<T> CompletionModel<T>
94where
95    T: HttpClientExt + Clone + 'static,
96{
97    pub(crate) async fn stream(
98        &self,
99        request: CompletionRequest,
100    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
101    {
102        let mut request = CohereCompletionRequest::try_from((self.model.as_ref(), request))?;
103        let span = if tracing::Span::current().is_disabled() {
104            info_span!(
105                target: "rig::completions",
106                "chat_streaming",
107                gen_ai.operation.name = "chat_streaming",
108                gen_ai.provider.name = "cohere",
109                gen_ai.request.model = self.model,
110                gen_ai.response.id = tracing::field::Empty,
111                gen_ai.response.model = self.model,
112                gen_ai.usage.output_tokens = tracing::field::Empty,
113                gen_ai.usage.input_tokens = tracing::field::Empty,
114            )
115        } else {
116            tracing::Span::current()
117        };
118
119        let params = json_utils::merge(
120            request.additional_params.unwrap_or(serde_json::json!({})),
121            serde_json::json!({"stream": true}),
122        );
123
124        request.additional_params = Some(params);
125
126        if enabled!(Level::TRACE) {
127            tracing::trace!(
128                target: "rig::streaming",
129                "Cohere streaming completion input: {}",
130                serde_json::to_string_pretty(&request)?
131            );
132        }
133
134        let body = serde_json::to_vec(&request)?;
135
136        let req = self.client.post("/v2/chat")?.body(body).unwrap();
137
138        let mut event_source = GenericEventSource::new(self.client.clone(), req);
139
140        let stream = stream! {
141            let mut current_tool_call: Option<(String, String, String)> = None;
142            let mut text_response = String::new();
143            let mut tool_calls = Vec::new();
144            let mut final_usage = None;
145
146            while let Some(event_result) = event_source.next().await {
147                match event_result {
148                    Ok(Event::Open) => {
149                        tracing::trace!("SSE connection opened");
150                        continue;
151                    }
152
153                    Ok(Event::Message(message)) => {
154                        let data_str = message.data.trim();
155                        if data_str.is_empty() || data_str == "[DONE]" {
156                            continue;
157                        }
158
159                        let event: StreamingEvent = match serde_json::from_str(data_str) {
160                            Ok(ev) => ev,
161                            Err(_) => {
162                                tracing::debug!("Couldn't parse SSE payload as StreamingEvent");
163                                continue;
164                            }
165                        };
166
167                        match event {
168                            StreamingEvent::ContentDelta { delta: Some(delta) } => {
169                                let Some(message) = &delta.message else { continue; };
170                                let Some(content) = &message.content else { continue; };
171                                let Some(text) = &content.text else { continue; };
172
173                                text_response += text;
174
175                                yield Ok(RawStreamingChoice::Message(text.clone()));
176                            },
177
178                            StreamingEvent::MessageEnd { delta: Some(delta) } => {
179                                let message = Message::Assistant {
180                                    tool_calls: tool_calls.clone(),
181                                    content: vec![AssistantContent::Text { text: text_response.clone() }],
182                                    tool_plan: None,
183                                    citations: vec![]
184                                };
185
186                                let span = tracing::Span::current();
187                                span.record_token_usage(&delta.usage);
188                                span.record_model_output(&vec![message]);
189
190                                final_usage = Some(delta.usage.clone());
191                                break;
192                            },
193
194                            StreamingEvent::ToolCallStart { delta: Some(delta) } => {
195                                let Some(message) = &delta.message else { continue; };
196                                let Some(tool_calls) = &message.tool_calls else { continue; };
197                                let Some(id) = tool_calls.id.clone() else { continue; };
198                                let Some(function) = &tool_calls.function else { continue; };
199                                let Some(name) = function.name.clone() else { continue; };
200                                let Some(arguments) = function.arguments.clone() else { continue; };
201
202                                current_tool_call = Some((id.clone(), name.clone(), arguments));
203
204                                yield Ok(RawStreamingChoice::ToolCallDelta {
205                                    id,
206                                    content: ToolCallDeltaContent::Name(name),
207                                });
208                            },
209
210                            StreamingEvent::ToolCallDelta { delta: Some(delta) } => {
211                                let Some(message) = &delta.message else { continue; };
212                                let Some(tool_calls) = &message.tool_calls else { continue; };
213                                let Some(function) = &tool_calls.function else { continue; };
214                                let Some(arguments) = function.arguments.clone() else { continue; };
215
216                                let Some(tc) = current_tool_call.clone() else { continue; };
217                                current_tool_call = Some((tc.0.clone(), tc.1, format!("{}{}", tc.2, arguments)));
218
219                                // Emit the delta so UI can show progress
220                                yield Ok(RawStreamingChoice::ToolCallDelta {
221                                    id: tc.0,
222                                    content: ToolCallDeltaContent::Delta(arguments),
223                                });
224                            },
225
226                            StreamingEvent::ToolCallEnd => {
227                                let Some(tc) = current_tool_call.clone() else { continue; };
228                                let Ok(args) = serde_json::from_str::<serde_json::Value>(&tc.2) else { continue; };
229
230                                tool_calls.push(ToolCall {
231                                    id: Some(tc.0.clone()),
232                                    r#type: Some(ToolType::Function),
233                                    function: Some(ToolCallFunction {
234                                        name: tc.1.clone(),
235                                        arguments: args.clone()
236                                    })
237                                });
238
239                                yield Ok(RawStreamingChoice::ToolCall(
240                                    RawStreamingToolCall::new(tc.0, tc.1, args)
241                                ));
242
243                                current_tool_call = None;
244                            },
245
246                            _ => {}
247                        }
248                    },
249                    Err(crate::http_client::Error::StreamEnded) => {
250                        break;
251                    }
252                    Err(err) => {
253                        tracing::error!(?err, "SSE error");
254                        yield Err(CompletionError::ProviderError(err.to_string()));
255                        break;
256                    }
257                }
258            }
259
260            // Ensure event source is closed when stream ends
261            event_source.close();
262
263            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
264                usage: final_usage.unwrap_or_default()
265            }))
266        }.instrument(span);
267
268        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
269            stream,
270        )))
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes a JSON payload into a `StreamingEvent`, panicking if it does not parse.
    fn parse_event(payload: serde_json::Value) -> StreamingEvent {
        serde_json::from_value(payload).unwrap()
    }

    /// Builds a `content-delta` payload carrying `text`.
    fn content_delta(text: &str) -> serde_json::Value {
        json!({
            "type": "content-delta",
            "delta": { "message": { "content": { "text": text } } }
        })
    }

    /// Builds a `tool-call-delta` payload carrying an argument fragment.
    fn tool_call_delta(arguments: &str) -> serde_json::Value {
        json!({
            "type": "tool-call-delta",
            "delta": {
                "message": { "tool_calls": { "function": { "arguments": arguments } } }
            }
        })
    }

    #[test]
    fn test_message_content_delta_deserialization() {
        let StreamingEvent::ContentDelta { delta } = parse_event(content_delta("Hello world"))
        else {
            panic!("Expected ContentDelta");
        };

        assert!(delta.is_some());
        let text = delta
            .and_then(|delta| delta.message)
            .and_then(|message| message.content)
            .and_then(|content| content.text);
        assert_eq!(text.as_deref(), Some("Hello world"));
    }

    #[test]
    fn test_tool_call_start_deserialization() {
        let payload = json!({
            "type": "tool-call-start",
            "delta": {
                "message": {
                    "tool_calls": {
                        "id": "call_123",
                        "function": { "name": "get_weather", "arguments": "{" }
                    }
                }
            }
        });

        let StreamingEvent::ToolCallStart { delta } = parse_event(payload) else {
            panic!("Expected ToolCallStart");
        };

        assert!(delta.is_some());
        let call = delta.unwrap().message.unwrap().tool_calls.unwrap();
        assert_eq!(call.id.as_deref(), Some("call_123"));
        assert_eq!(call.function.unwrap().name.as_deref(), Some("get_weather"));
    }

    #[test]
    fn test_tool_call_delta_deserialization() {
        let StreamingEvent::ToolCallDelta { delta } =
            parse_event(tool_call_delta("\"location\""))
        else {
            panic!("Expected ToolCallDelta");
        };

        assert!(delta.is_some());
        let arguments = delta
            .and_then(|delta| delta.message)
            .and_then(|message| message.tool_calls)
            .and_then(|call| call.function)
            .and_then(|function| function.arguments);
        assert_eq!(arguments.as_deref(), Some("\"location\""));
    }

    #[test]
    fn test_tool_call_end_deserialization() {
        let event = parse_event(json!({ "type": "tool-call-end" }));
        if !matches!(event, StreamingEvent::ToolCallEnd) {
            panic!("Expected ToolCallEnd");
        }
    }

    #[test]
    fn test_message_end_with_usage_deserialization() {
        let payload = json!({
            "type": "message-end",
            "delta": {
                "usage": { "tokens": { "input_tokens": 100, "output_tokens": 50 } }
            }
        });

        let StreamingEvent::MessageEnd { delta } = parse_event(payload) else {
            panic!("Expected MessageEnd");
        };

        assert!(delta.is_some());
        let tokens = delta.unwrap().usage.unwrap().tokens.unwrap();
        assert_eq!(tokens.input_tokens, Some(100.0));
        assert_eq!(tokens.output_tokens, Some(50.0));
    }

    #[test]
    fn test_streaming_event_order() {
        // A typical event sequence: every payload must deserialize.
        let events = vec![
            json!({"type": "message-start"}),
            json!({"type": "content-start"}),
            content_delta("Sure, "),
            content_delta("I can help with that."),
            json!({"type": "content-end"}),
            json!({"type": "tool-plan"}),
            json!({
                "type": "tool-call-start",
                "delta": {
                    "message": {
                        "tool_calls": {
                            "id": "call_abc",
                            "function": { "name": "search", "arguments": "" }
                        }
                    }
                }
            }),
            tool_call_delta("{\"query\":"),
            tool_call_delta("\"Rust\"}"),
            json!({"type": "tool-call-end"}),
            json!({
                "type": "message-end",
                "delta": {
                    "usage": { "tokens": { "input_tokens": 50, "output_tokens": 25 } }
                }
            }),
        ];

        for (i, event_json) in events.into_iter().enumerate() {
            let result = serde_json::from_value::<StreamingEvent>(event_json);
            assert!(
                result.is_ok(),
                "Failed to deserialize event at index {}: {:?}",
                i,
                result.err()
            );
        }
    }
}