Skip to main content

rig/providers/cohere/
streaming.rs

1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::http_client::HttpClientExt;
3use crate::http_client::sse::{Event, GenericEventSource};
4use crate::providers::cohere::CompletionModel;
5use crate::providers::cohere::completion::{
6    AssistantContent, CohereCompletionRequest, Message, ToolCall, ToolCallFunction, ToolType, Usage,
7};
8use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent};
9use crate::telemetry::SpanCombinator;
10use crate::{json_utils, streaming};
11use async_stream::stream;
12use futures::StreamExt;
13use serde::{Deserialize, Serialize};
14use tracing::{Level, enabled, info_span};
15use tracing_futures::Instrument;
16
/// One server-sent event payload from the streaming chat endpoint.
///
/// The variant is selected by the payload's `type` field, whose value is the kebab-case
/// form of the variant name (for example `content-delta` or `tool-call-start`). A payload
/// whose `type` matches none of the variants fails to deserialize.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "kebab-case", tag = "type")]
enum StreamingEvent {
    /// `message-start`; no payload fields are read from it.
    MessageStart,
    /// `content-start`; no payload fields are read from it.
    ContentStart,
    /// `content-delta`; the delta carries a fragment of assistant text.
    ContentDelta { delta: Option<Delta> },
    /// `content-end`; no payload fields are read from it.
    ContentEnd,
    /// `tool-plan`; no payload fields are read from it.
    // NOTE(review): confirm the wire name against Cohere's streaming documentation; if the
    // API sends a different tag (e.g. `tool-plan-delta`) this variant never matches.
    ToolPlan,
    /// `tool-call-start`; the delta carries the tool call id, function name and the first
    /// piece of the argument string.
    ToolCallStart { delta: Option<Delta> },
    /// `tool-call-delta`; the delta carries a further piece of the argument string.
    ToolCallDelta { delta: Option<Delta> },
    /// `tool-call-end`; marks the end of the tool call currently being streamed.
    ToolCallEnd,
    /// `message-end`; the delta carries the usage statistics for the whole response.
    MessageEnd { delta: Option<MessageEndDelta> },
}
30
/// Content portion of a streamed message delta.
#[derive(Debug, Deserialize)]
struct MessageContentDelta {
    /// Text fragment carried by this delta, if any.
    text: Option<String>,
}
35
/// Function portion of a streamed tool call delta.
#[derive(Debug, Deserialize)]
struct MessageToolFunctionDelta {
    /// Name of the function being called; read from `tool-call-start` events.
    name: Option<String>,
    /// Piece of the argument string; the pieces are concatenated across events and parsed
    /// once the tool call ends.
    arguments: Option<String>,
}
41
/// Tool call portion of a streamed message delta.
#[derive(Debug, Deserialize)]
struct MessageToolCallDelta {
    /// Tool call id; read from `tool-call-start` events.
    id: Option<String>,
    /// Function name and/or argument fragment for this tool call.
    function: Option<MessageToolFunctionDelta>,
}
47
/// The `message` object inside a streamed delta.
#[derive(Debug, Deserialize)]
struct MessageDelta {
    /// Text content update, present on content events.
    content: Option<MessageContentDelta>,
    /// Tool call update, present on tool call events. Despite the plural field name this
    /// deserializes a single object, not a list.
    tool_calls: Option<MessageToolCallDelta>,
}
53
/// The `delta` object of content and tool call events.
#[derive(Debug, Deserialize)]
struct Delta {
    /// Incremental message update carried by the event, if any.
    message: Option<MessageDelta>,
}
58
/// The `delta` object of a `message-end` event.
#[derive(Debug, Deserialize)]
struct MessageEndDelta {
    /// Usage statistics reported for the completed response, if any.
    usage: Option<Usage>,
}
63
/// Final response item yielded at the end of a Cohere streaming completion.
#[derive(Clone, Serialize, Deserialize)]
pub struct StreamingCompletionResponse {
    /// Usage taken from the `message-end` event; `None` when the stream finished without
    /// reporting any.
    pub usage: Option<Usage>,
}
68
69impl GetTokenUsage for StreamingCompletionResponse {
70    fn token_usage(&self) -> Option<crate::completion::Usage> {
71        let tokens = self
72            .usage
73            .clone()
74            .and_then(|response| response.tokens)
75            .map(|tokens| {
76                (
77                    tokens.input_tokens.map(|x| x as u64),
78                    tokens.output_tokens.map(|y| y as u64),
79                )
80            });
81        let Some((Some(input), Some(output))) = tokens else {
82            return None;
83        };
84        let mut usage = crate::completion::Usage::new();
85        usage.input_tokens = input;
86        usage.output_tokens = output;
87        usage.total_tokens = input + output;
88
89        Some(usage)
90    }
91}
92
93impl<T> CompletionModel<T>
94where
95    T: HttpClientExt + Clone + 'static,
96{
97    pub(crate) async fn stream(
98        &self,
99        request: CompletionRequest,
100    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
101    {
102        let mut request = CohereCompletionRequest::try_from((self.model.as_ref(), request))?;
103        let span = if tracing::Span::current().is_disabled() {
104            info_span!(
105                target: "rig::completions",
106                "chat_streaming",
107                gen_ai.operation.name = "chat_streaming",
108                gen_ai.provider.name = "cohere",
109                gen_ai.request.model = self.model,
110                gen_ai.response.id = tracing::field::Empty,
111                gen_ai.response.model = self.model,
112                gen_ai.usage.output_tokens = tracing::field::Empty,
113                gen_ai.usage.input_tokens = tracing::field::Empty,
114                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
115            )
116        } else {
117            tracing::Span::current()
118        };
119
120        let params = json_utils::merge(
121            request.additional_params.unwrap_or(serde_json::json!({})),
122            serde_json::json!({"stream": true}),
123        );
124
125        request.additional_params = Some(params);
126
127        if enabled!(Level::TRACE) {
128            tracing::trace!(
129                target: "rig::streaming",
130                "Cohere streaming completion input: {}",
131                serde_json::to_string_pretty(&request)?
132            );
133        }
134
135        let body = serde_json::to_vec(&request)?;
136
137        let req = self
138            .client
139            .post("/v2/chat")?
140            .body(body)
141            .map_err(|e| CompletionError::HttpError(e.into()))?;
142
143        let mut event_source = GenericEventSource::new(self.client.clone(), req);
144
145        let stream = stream! {
146            let mut current_tool_call: Option<(String, String, String, String)> = None;
147            let mut text_response = String::new();
148            let mut tool_calls = Vec::new();
149            let mut final_usage = None;
150
151            while let Some(event_result) = event_source.next().await {
152                match event_result {
153                    Ok(Event::Open) => {
154                        tracing::trace!("SSE connection opened");
155                        continue;
156                    }
157
158                    Ok(Event::Message(message)) => {
159                        let data_str = message.data.trim();
160                        if data_str.is_empty() || data_str == "[DONE]" {
161                            continue;
162                        }
163
164                        let event: StreamingEvent = match serde_json::from_str(data_str) {
165                            Ok(ev) => ev,
166                            Err(_) => {
167                                tracing::debug!("Couldn't parse SSE payload as StreamingEvent");
168                                continue;
169                            }
170                        };
171
172                        match event {
173                            StreamingEvent::ContentDelta { delta: Some(delta) } => {
174                                let Some(message) = &delta.message else { continue; };
175                                let Some(content) = &message.content else { continue; };
176                                let Some(text) = &content.text else { continue; };
177
178                                text_response += text;
179
180                                yield Ok(RawStreamingChoice::Message(text.clone()));
181                            },
182
183                            StreamingEvent::MessageEnd { delta: Some(delta) } => {
184                                let message = Message::Assistant {
185                                    tool_calls: tool_calls.clone(),
186                                    content: vec![AssistantContent::Text { text: text_response.clone() }],
187                                    tool_plan: None,
188                                    citations: vec![]
189                                };
190
191                                let span = tracing::Span::current();
192                                span.record_token_usage(&delta.usage);
193                                span.record_model_output(&vec![message]);
194
195                                final_usage = Some(delta.usage.clone());
196                                break;
197                            },
198
199                            StreamingEvent::ToolCallStart { delta: Some(delta) } => {
200                                let Some(message) = &delta.message else { continue; };
201                                let Some(tool_calls) = &message.tool_calls else { continue; };
202                                let Some(id) = tool_calls.id.clone() else { continue; };
203                                let Some(function) = &tool_calls.function else { continue; };
204                                let Some(name) = function.name.clone() else { continue; };
205                                let Some(arguments) = function.arguments.clone() else { continue; };
206
207                                let internal_call_id = nanoid::nanoid!();
208                                current_tool_call = Some((id.clone(), internal_call_id.clone(), name.clone(), arguments));
209
210                                yield Ok(RawStreamingChoice::ToolCallDelta {
211                                    id,
212                                    internal_call_id,
213                                    content: ToolCallDeltaContent::Name(name),
214                                });
215                            },
216
217                            StreamingEvent::ToolCallDelta { delta: Some(delta) } => {
218                                let Some(message) = &delta.message else { continue; };
219                                let Some(tool_calls) = &message.tool_calls else { continue; };
220                                let Some(function) = &tool_calls.function else { continue; };
221                                let Some(arguments) = function.arguments.clone() else { continue; };
222
223                                let Some(tc) = current_tool_call.clone() else { continue; };
224                                current_tool_call = Some((tc.0.clone(), tc.1.clone(), tc.2, format!("{}{}", tc.3, arguments)));
225
226                                // Emit the delta so UI can show progress
227                                yield Ok(RawStreamingChoice::ToolCallDelta {
228                                    id: tc.0,
229                                    internal_call_id: tc.1,
230                                    content: ToolCallDeltaContent::Delta(arguments),
231                                });
232                            },
233
234                            StreamingEvent::ToolCallEnd => {
235                                let Some(tc) = current_tool_call.clone() else { continue; };
236                                let Ok(args) = json_utils::parse_tool_arguments(&tc.3) else { continue; };
237
238                                tool_calls.push(ToolCall {
239                                    id: Some(tc.0.clone()),
240                                    r#type: Some(ToolType::Function),
241                                    function: Some(ToolCallFunction {
242                                        name: tc.2.clone(),
243                                        arguments: args.clone()
244                                    })
245                                });
246
247                                let raw_tool_call = RawStreamingToolCall::new(tc.0, tc.2, args)
248                                    .with_internal_call_id(tc.1);
249                                yield Ok(RawStreamingChoice::ToolCall(raw_tool_call));
250
251                                current_tool_call = None;
252                            },
253
254                            _ => {}
255                        }
256                    },
257                    Err(crate::http_client::Error::StreamEnded) => {
258                        break;
259                    }
260                    Err(err) => {
261                        tracing::error!(?err, "SSE error");
262                        yield Err(CompletionError::ProviderError(err.to_string()));
263                        break;
264                    }
265                }
266            }
267
268            // Ensure event source is closed when stream ends
269            event_source.close();
270
271            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
272                usage: final_usage.unwrap_or_default()
273            }))
274        }.instrument(span);
275
276        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
277            stream,
278        )))
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};

    /// Deserializes a JSON payload into a [`StreamingEvent`], panicking if it is rejected.
    fn parse_event(payload: Value) -> StreamingEvent {
        serde_json::from_value(payload).unwrap()
    }

    /// Builds a `content-delta` payload carrying `text`.
    fn content_delta(text: &str) -> Value {
        json!({
            "type": "content-delta",
            "delta": { "message": { "content": { "text": text } } }
        })
    }

    /// Builds a `tool-call-start` payload for the given id, function name and first
    /// argument fragment.
    fn tool_call_start(id: &str, name: &str, arguments: &str) -> Value {
        json!({
            "type": "tool-call-start",
            "delta": {
                "message": {
                    "tool_calls": {
                        "id": id,
                        "function": { "name": name, "arguments": arguments }
                    }
                }
            }
        })
    }

    /// Builds a `tool-call-delta` payload carrying one argument fragment.
    fn tool_call_delta(arguments: &str) -> Value {
        json!({
            "type": "tool-call-delta",
            "delta": {
                "message": { "tool_calls": { "function": { "arguments": arguments } } }
            }
        })
    }

    /// Builds a `message-end` payload reporting the given token counts.
    fn message_end(input_tokens: u64, output_tokens: u64) -> Value {
        json!({
            "type": "message-end",
            "delta": {
                "usage": {
                    "tokens": {
                        "input_tokens": input_tokens,
                        "output_tokens": output_tokens
                    }
                }
            }
        })
    }

    #[test]
    fn test_message_content_delta_deserialization() {
        let StreamingEvent::ContentDelta { delta } = parse_event(content_delta("Hello world"))
        else {
            panic!("Expected ContentDelta");
        };

        let text = delta
            .and_then(|delta| delta.message)
            .and_then(|message| message.content)
            .and_then(|content| content.text);
        assert_eq!(text, Some("Hello world".to_string()));
    }

    #[test]
    fn test_tool_call_start_deserialization() {
        let payload = tool_call_start("call_123", "get_weather", "{");
        let StreamingEvent::ToolCallStart { delta } = parse_event(payload) else {
            panic!("Expected ToolCallStart");
        };

        let call = delta
            .and_then(|delta| delta.message)
            .and_then(|message| message.tool_calls)
            .unwrap();
        assert_eq!(call.id, Some("call_123".to_string()));
        assert_eq!(
            call.function.and_then(|function| function.name),
            Some("get_weather".to_string())
        );
    }

    #[test]
    fn test_tool_call_delta_deserialization() {
        let StreamingEvent::ToolCallDelta { delta } = parse_event(tool_call_delta("\"location\""))
        else {
            panic!("Expected ToolCallDelta");
        };

        let fragment = delta
            .and_then(|delta| delta.message)
            .and_then(|message| message.tool_calls)
            .and_then(|call| call.function)
            .and_then(|function| function.arguments);
        assert_eq!(fragment, Some("\"location\"".to_string()));
    }

    #[test]
    fn test_tool_call_end_deserialization() {
        let event = parse_event(json!({ "type": "tool-call-end" }));
        if !matches!(event, StreamingEvent::ToolCallEnd) {
            panic!("Expected ToolCallEnd");
        }
    }

    #[test]
    fn test_message_end_with_usage_deserialization() {
        let StreamingEvent::MessageEnd { delta } = parse_event(message_end(100, 50)) else {
            panic!("Expected MessageEnd");
        };

        let counts = delta
            .and_then(|delta| delta.usage)
            .and_then(|usage| usage.tokens)
            .unwrap();
        assert_eq!(counts.input_tokens, Some(100.0));
        assert_eq!(counts.output_tokens, Some(50.0));
    }

    #[test]
    fn test_streaming_event_order() {
        // A typical event sequence: text, then one tool call, then the end of the message.
        let sequence = vec![
            json!({"type": "message-start"}),
            json!({"type": "content-start"}),
            content_delta("Sure, "),
            content_delta("I can help with that."),
            json!({"type": "content-end"}),
            json!({"type": "tool-plan"}),
            tool_call_start("call_abc", "search", ""),
            tool_call_delta("{\"query\":"),
            tool_call_delta("\"Rust\"}"),
            json!({"type": "tool-call-end"}),
            message_end(50, 25),
        ];

        for (index, payload) in sequence.into_iter().enumerate() {
            let outcome = serde_json::from_value::<StreamingEvent>(payload);
            assert!(
                outcome.is_ok(),
                "Failed to deserialize event at index {}: {:?}",
                index,
                outcome.err()
            );
        }
    }
}