1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::http_client::HttpClientExt;
3use crate::http_client::sse::{Event, GenericEventSource};
4use crate::providers::cohere::CompletionModel;
5use crate::providers::cohere::completion::{
6 AssistantContent, CohereCompletionRequest, Message, ToolCall, ToolCallFunction, ToolType, Usage,
7};
8use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent};
9use crate::telemetry::SpanCombinator;
10use crate::{json_utils, streaming};
11use async_stream::stream;
12use futures::StreamExt;
13use serde::{Deserialize, Serialize};
14use tracing::{Level, enabled, info_span};
15use tracing_futures::Instrument;
16
17#[derive(Debug, Deserialize)]
18#[serde(rename_all = "kebab-case", tag = "type")]
19enum StreamingEvent {
20 MessageStart,
21 ContentStart,
22 ContentDelta { delta: Option<Delta> },
23 ContentEnd,
24 ToolPlan,
25 ToolCallStart { delta: Option<Delta> },
26 ToolCallDelta { delta: Option<Delta> },
27 ToolCallEnd,
28 MessageEnd { delta: Option<MessageEndDelta> },
29}
30
31#[derive(Debug, Deserialize)]
32struct MessageContentDelta {
33 text: Option<String>,
34}
35
36#[derive(Debug, Deserialize)]
37struct MessageToolFunctionDelta {
38 name: Option<String>,
39 arguments: Option<String>,
40}
41
42#[derive(Debug, Deserialize)]
43struct MessageToolCallDelta {
44 id: Option<String>,
45 function: Option<MessageToolFunctionDelta>,
46}
47
48#[derive(Debug, Deserialize)]
49struct MessageDelta {
50 content: Option<MessageContentDelta>,
51 tool_calls: Option<MessageToolCallDelta>,
52}
53
54#[derive(Debug, Deserialize)]
55struct Delta {
56 message: Option<MessageDelta>,
57}
58
59#[derive(Debug, Deserialize)]
60struct MessageEndDelta {
61 usage: Option<Usage>,
62}
63
64#[derive(Clone, Serialize, Deserialize)]
65pub struct StreamingCompletionResponse {
66 pub usage: Option<Usage>,
67}
68
69impl GetTokenUsage for StreamingCompletionResponse {
70 fn token_usage(&self) -> Option<crate::completion::Usage> {
71 let tokens = self
72 .usage
73 .clone()
74 .and_then(|response| response.tokens)
75 .map(|tokens| {
76 (
77 tokens.input_tokens.map(|x| x as u64),
78 tokens.output_tokens.map(|y| y as u64),
79 )
80 });
81 let Some((Some(input), Some(output))) = tokens else {
82 return None;
83 };
84 let mut usage = crate::completion::Usage::new();
85 usage.input_tokens = input;
86 usage.output_tokens = output;
87 usage.total_tokens = input + output;
88
89 Some(usage)
90 }
91}
92
93impl<T> CompletionModel<T>
94where
95 T: HttpClientExt + Clone + 'static,
96{
97 pub(crate) async fn stream(
98 &self,
99 request: CompletionRequest,
100 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
101 {
102 let mut request = CohereCompletionRequest::try_from((self.model.as_ref(), request))?;
103 let span = if tracing::Span::current().is_disabled() {
104 info_span!(
105 target: "rig::completions",
106 "chat_streaming",
107 gen_ai.operation.name = "chat_streaming",
108 gen_ai.provider.name = "cohere",
109 gen_ai.request.model = self.model,
110 gen_ai.response.id = tracing::field::Empty,
111 gen_ai.response.model = self.model,
112 gen_ai.usage.output_tokens = tracing::field::Empty,
113 gen_ai.usage.input_tokens = tracing::field::Empty,
114 )
115 } else {
116 tracing::Span::current()
117 };
118
119 let params = json_utils::merge(
120 request.additional_params.unwrap_or(serde_json::json!({})),
121 serde_json::json!({"stream": true}),
122 );
123
124 request.additional_params = Some(params);
125
126 if enabled!(Level::TRACE) {
127 tracing::trace!(
128 target: "rig::streaming",
129 "Cohere streaming completion input: {}",
130 serde_json::to_string_pretty(&request)?
131 );
132 }
133
134 let body = serde_json::to_vec(&request)?;
135
136 let req = self.client.post("/v2/chat")?.body(body).unwrap();
137
138 let mut event_source = GenericEventSource::new(self.client.clone(), req);
139
140 let stream = stream! {
141 let mut current_tool_call: Option<(String, String, String, String)> = None;
142 let mut text_response = String::new();
143 let mut tool_calls = Vec::new();
144 let mut final_usage = None;
145
146 while let Some(event_result) = event_source.next().await {
147 match event_result {
148 Ok(Event::Open) => {
149 tracing::trace!("SSE connection opened");
150 continue;
151 }
152
153 Ok(Event::Message(message)) => {
154 let data_str = message.data.trim();
155 if data_str.is_empty() || data_str == "[DONE]" {
156 continue;
157 }
158
159 let event: StreamingEvent = match serde_json::from_str(data_str) {
160 Ok(ev) => ev,
161 Err(_) => {
162 tracing::debug!("Couldn't parse SSE payload as StreamingEvent");
163 continue;
164 }
165 };
166
167 match event {
168 StreamingEvent::ContentDelta { delta: Some(delta) } => {
169 let Some(message) = &delta.message else { continue; };
170 let Some(content) = &message.content else { continue; };
171 let Some(text) = &content.text else { continue; };
172
173 text_response += text;
174
175 yield Ok(RawStreamingChoice::Message(text.clone()));
176 },
177
178 StreamingEvent::MessageEnd { delta: Some(delta) } => {
179 let message = Message::Assistant {
180 tool_calls: tool_calls.clone(),
181 content: vec![AssistantContent::Text { text: text_response.clone() }],
182 tool_plan: None,
183 citations: vec![]
184 };
185
186 let span = tracing::Span::current();
187 span.record_token_usage(&delta.usage);
188 span.record_model_output(&vec![message]);
189
190 final_usage = Some(delta.usage.clone());
191 break;
192 },
193
194 StreamingEvent::ToolCallStart { delta: Some(delta) } => {
195 let Some(message) = &delta.message else { continue; };
196 let Some(tool_calls) = &message.tool_calls else { continue; };
197 let Some(id) = tool_calls.id.clone() else { continue; };
198 let Some(function) = &tool_calls.function else { continue; };
199 let Some(name) = function.name.clone() else { continue; };
200 let Some(arguments) = function.arguments.clone() else { continue; };
201
202 let internal_call_id = nanoid::nanoid!();
203 current_tool_call = Some((id.clone(), internal_call_id.clone(), name.clone(), arguments));
204
205 yield Ok(RawStreamingChoice::ToolCallDelta {
206 id,
207 internal_call_id,
208 content: ToolCallDeltaContent::Name(name),
209 });
210 },
211
212 StreamingEvent::ToolCallDelta { delta: Some(delta) } => {
213 let Some(message) = &delta.message else { continue; };
214 let Some(tool_calls) = &message.tool_calls else { continue; };
215 let Some(function) = &tool_calls.function else { continue; };
216 let Some(arguments) = function.arguments.clone() else { continue; };
217
218 let Some(tc) = current_tool_call.clone() else { continue; };
219 current_tool_call = Some((tc.0.clone(), tc.1.clone(), tc.2, format!("{}{}", tc.3, arguments)));
220
221 yield Ok(RawStreamingChoice::ToolCallDelta {
223 id: tc.0,
224 internal_call_id: tc.1,
225 content: ToolCallDeltaContent::Delta(arguments),
226 });
227 },
228
229 StreamingEvent::ToolCallEnd => {
230 let Some(tc) = current_tool_call.clone() else { continue; };
231 let Ok(args) = serde_json::from_str::<serde_json::Value>(&tc.3) else { continue; };
232
233 tool_calls.push(ToolCall {
234 id: Some(tc.0.clone()),
235 r#type: Some(ToolType::Function),
236 function: Some(ToolCallFunction {
237 name: tc.2.clone(),
238 arguments: args.clone()
239 })
240 });
241
242 let raw_tool_call = RawStreamingToolCall::new(tc.0, tc.2, args)
243 .with_internal_call_id(tc.1);
244 yield Ok(RawStreamingChoice::ToolCall(raw_tool_call));
245
246 current_tool_call = None;
247 },
248
249 _ => {}
250 }
251 },
252 Err(crate::http_client::Error::StreamEnded) => {
253 break;
254 }
255 Err(err) => {
256 tracing::error!(?err, "SSE error");
257 yield Err(CompletionError::ProviderError(err.to_string()));
258 break;
259 }
260 }
261 }
262
263 event_source.close();
265
266 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
267 usage: final_usage.unwrap_or_default()
268 }))
269 }.instrument(span);
270
271 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
272 stream,
273 )))
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes `payload` into a [`StreamingEvent`], panicking if it is malformed.
    fn parse_event(payload: serde_json::Value) -> StreamingEvent {
        serde_json::from_value(payload).unwrap()
    }

    /// A `content-delta` event exposes its text under `delta.message.content.text`.
    #[test]
    fn test_message_content_delta_deserialization() {
        let payload = json!({
            "type": "content-delta",
            "delta": {
                "message": {
                    "content": {
                        "text": "Hello world"
                    }
                }
            }
        });

        let StreamingEvent::ContentDelta { delta } = parse_event(payload) else {
            panic!("Expected ContentDelta");
        };
        assert!(delta.is_some());
        let content = delta.unwrap().message.unwrap().content.unwrap();
        assert_eq!(content.text, Some("Hello world".to_string()));
    }

    /// A `tool-call-start` event carries the call id and the function name.
    #[test]
    fn test_tool_call_start_deserialization() {
        let payload = json!({
            "type": "tool-call-start",
            "delta": {
                "message": {
                    "tool_calls": {
                        "id": "call_123",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{"
                        }
                    }
                }
            }
        });

        let StreamingEvent::ToolCallStart { delta } = parse_event(payload) else {
            panic!("Expected ToolCallStart");
        };
        assert!(delta.is_some());
        let tool_call = delta.unwrap().message.unwrap().tool_calls.unwrap();
        assert_eq!(tool_call.id, Some("call_123".to_string()));
        let function = tool_call.function.unwrap();
        assert_eq!(function.name, Some("get_weather".to_string()));
    }

    /// A `tool-call-delta` event carries only an argument fragment.
    #[test]
    fn test_tool_call_delta_deserialization() {
        let payload = json!({
            "type": "tool-call-delta",
            "delta": {
                "message": {
                    "tool_calls": {
                        "function": {
                            "arguments": "\"location\""
                        }
                    }
                }
            }
        });

        let StreamingEvent::ToolCallDelta { delta } = parse_event(payload) else {
            panic!("Expected ToolCallDelta");
        };
        assert!(delta.is_some());
        let function = delta
            .unwrap()
            .message
            .unwrap()
            .tool_calls
            .unwrap()
            .function
            .unwrap();
        assert_eq!(function.arguments, Some("\"location\"".to_string()));
    }

    /// A bare `tool-call-end` event maps to the unit variant.
    #[test]
    fn test_tool_call_end_deserialization() {
        let payload = json!({
            "type": "tool-call-end"
        });

        let event = parse_event(payload);
        assert!(matches!(event, StreamingEvent::ToolCallEnd), "Expected ToolCallEnd");
    }

    /// A `message-end` event exposes the token usage under `delta.usage.tokens`.
    #[test]
    fn test_message_end_with_usage_deserialization() {
        let payload = json!({
            "type": "message-end",
            "delta": {
                "usage": {
                    "tokens": {
                        "input_tokens": 100,
                        "output_tokens": 50
                    }
                }
            }
        });

        let StreamingEvent::MessageEnd { delta } = parse_event(payload) else {
            panic!("Expected MessageEnd");
        };
        assert!(delta.is_some());
        let tokens = delta.unwrap().usage.unwrap().tokens.unwrap();
        assert_eq!(tokens.input_tokens, Some(100.0));
        assert_eq!(tokens.output_tokens, Some(50.0));
    }

    /// Every event of a realistic text-plus-tool-call response deserializes.
    #[test]
    fn test_streaming_event_order() {
        let text_delta = |text: &str| {
            json!({
                "type": "content-delta",
                "delta": {
                    "message": {
                        "content": {
                            "text": text
                        }
                    }
                }
            })
        };
        let arguments_delta = |arguments: &str| {
            json!({
                "type": "tool-call-delta",
                "delta": {
                    "message": {
                        "tool_calls": {
                            "function": {
                                "arguments": arguments
                            }
                        }
                    }
                }
            })
        };

        let events = vec![
            json!({"type": "message-start"}),
            json!({"type": "content-start"}),
            text_delta("Sure, "),
            text_delta("I can help with that."),
            json!({"type": "content-end"}),
            json!({"type": "tool-plan"}),
            json!({
                "type": "tool-call-start",
                "delta": {
                    "message": {
                        "tool_calls": {
                            "id": "call_abc",
                            "function": {
                                "name": "search",
                                "arguments": ""
                            }
                        }
                    }
                }
            }),
            arguments_delta("{\"query\":"),
            arguments_delta("\"Rust\"}"),
            json!({"type": "tool-call-end"}),
            json!({
                "type": "message-end",
                "delta": {
                    "usage": {
                        "tokens": {
                            "input_tokens": 50,
                            "output_tokens": 25
                        }
                    }
                }
            }),
        ];

        for (index, payload) in events.into_iter().enumerate() {
            let result = serde_json::from_value::<StreamingEvent>(payload);
            assert!(
                result.is_ok(),
                "Failed to deserialize event at index {}: {:?}",
                index,
                result.err()
            );
        }
    }
}