1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::http_client::HttpClientExt;
3use crate::http_client::sse::{Event, GenericEventSource};
4use crate::providers::cohere::CompletionModel;
5use crate::providers::cohere::completion::{
6 AssistantContent, CohereCompletionRequest, Message, ToolCall, ToolCallFunction, ToolType, Usage,
7};
8use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent};
9use crate::telemetry::SpanCombinator;
10use crate::{json_utils, streaming};
11use async_stream::stream;
12use futures::StreamExt;
13use serde::{Deserialize, Serialize};
14use tracing::{Level, enabled, info_span};
15use tracing_futures::Instrument;
16
17#[derive(Debug, Deserialize)]
18#[serde(rename_all = "kebab-case", tag = "type")]
19enum StreamingEvent {
20 MessageStart,
21 ContentStart,
22 ContentDelta { delta: Option<Delta> },
23 ContentEnd,
24 ToolPlan,
25 ToolCallStart { delta: Option<Delta> },
26 ToolCallDelta { delta: Option<Delta> },
27 ToolCallEnd,
28 MessageEnd { delta: Option<MessageEndDelta> },
29}
30
31#[derive(Debug, Deserialize)]
32struct MessageContentDelta {
33 text: Option<String>,
34}
35
36#[derive(Debug, Deserialize)]
37struct MessageToolFunctionDelta {
38 name: Option<String>,
39 arguments: Option<String>,
40}
41
42#[derive(Debug, Deserialize)]
43struct MessageToolCallDelta {
44 id: Option<String>,
45 function: Option<MessageToolFunctionDelta>,
46}
47
48#[derive(Debug, Deserialize)]
49struct MessageDelta {
50 content: Option<MessageContentDelta>,
51 tool_calls: Option<MessageToolCallDelta>,
52}
53
54#[derive(Debug, Deserialize)]
55struct Delta {
56 message: Option<MessageDelta>,
57}
58
59#[derive(Debug, Deserialize)]
60struct MessageEndDelta {
61 usage: Option<Usage>,
62}
63
64#[derive(Clone, Serialize, Deserialize)]
65pub struct StreamingCompletionResponse {
66 pub usage: Option<Usage>,
67}
68
69impl GetTokenUsage for StreamingCompletionResponse {
70 fn token_usage(&self) -> Option<crate::completion::Usage> {
71 let tokens = self
72 .usage
73 .clone()
74 .and_then(|response| response.tokens)
75 .map(|tokens| {
76 (
77 tokens.input_tokens.map(|x| x as u64),
78 tokens.output_tokens.map(|y| y as u64),
79 )
80 });
81 let Some((Some(input), Some(output))) = tokens else {
82 return None;
83 };
84 let mut usage = crate::completion::Usage::new();
85 usage.input_tokens = input;
86 usage.output_tokens = output;
87 usage.total_tokens = input + output;
88
89 Some(usage)
90 }
91}
92
93impl<T> CompletionModel<T>
94where
95 T: HttpClientExt + Clone + 'static,
96{
97 pub(crate) async fn stream(
98 &self,
99 request: CompletionRequest,
100 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
101 {
102 let mut request = CohereCompletionRequest::try_from((self.model.as_ref(), request))?;
103 let span = if tracing::Span::current().is_disabled() {
104 info_span!(
105 target: "rig::completions",
106 "chat_streaming",
107 gen_ai.operation.name = "chat_streaming",
108 gen_ai.provider.name = "cohere",
109 gen_ai.request.model = self.model,
110 gen_ai.response.id = tracing::field::Empty,
111 gen_ai.response.model = self.model,
112 gen_ai.usage.output_tokens = tracing::field::Empty,
113 gen_ai.usage.input_tokens = tracing::field::Empty,
114 gen_ai.usage.cached_tokens = tracing::field::Empty,
115 )
116 } else {
117 tracing::Span::current()
118 };
119
120 let params = json_utils::merge(
121 request.additional_params.unwrap_or(serde_json::json!({})),
122 serde_json::json!({"stream": true}),
123 );
124
125 request.additional_params = Some(params);
126
127 if enabled!(Level::TRACE) {
128 tracing::trace!(
129 target: "rig::streaming",
130 "Cohere streaming completion input: {}",
131 serde_json::to_string_pretty(&request)?
132 );
133 }
134
135 let body = serde_json::to_vec(&request)?;
136
137 let req = self.client.post("/v2/chat")?.body(body).unwrap();
138
139 let mut event_source = GenericEventSource::new(self.client.clone(), req);
140
141 let stream = stream! {
142 let mut current_tool_call: Option<(String, String, String, String)> = None;
143 let mut text_response = String::new();
144 let mut tool_calls = Vec::new();
145 let mut final_usage = None;
146
147 while let Some(event_result) = event_source.next().await {
148 match event_result {
149 Ok(Event::Open) => {
150 tracing::trace!("SSE connection opened");
151 continue;
152 }
153
154 Ok(Event::Message(message)) => {
155 let data_str = message.data.trim();
156 if data_str.is_empty() || data_str == "[DONE]" {
157 continue;
158 }
159
160 let event: StreamingEvent = match serde_json::from_str(data_str) {
161 Ok(ev) => ev,
162 Err(_) => {
163 tracing::debug!("Couldn't parse SSE payload as StreamingEvent");
164 continue;
165 }
166 };
167
168 match event {
169 StreamingEvent::ContentDelta { delta: Some(delta) } => {
170 let Some(message) = &delta.message else { continue; };
171 let Some(content) = &message.content else { continue; };
172 let Some(text) = &content.text else { continue; };
173
174 text_response += text;
175
176 yield Ok(RawStreamingChoice::Message(text.clone()));
177 },
178
179 StreamingEvent::MessageEnd { delta: Some(delta) } => {
180 let message = Message::Assistant {
181 tool_calls: tool_calls.clone(),
182 content: vec![AssistantContent::Text { text: text_response.clone() }],
183 tool_plan: None,
184 citations: vec![]
185 };
186
187 let span = tracing::Span::current();
188 span.record_token_usage(&delta.usage);
189 span.record_model_output(&vec![message]);
190
191 final_usage = Some(delta.usage.clone());
192 break;
193 },
194
195 StreamingEvent::ToolCallStart { delta: Some(delta) } => {
196 let Some(message) = &delta.message else { continue; };
197 let Some(tool_calls) = &message.tool_calls else { continue; };
198 let Some(id) = tool_calls.id.clone() else { continue; };
199 let Some(function) = &tool_calls.function else { continue; };
200 let Some(name) = function.name.clone() else { continue; };
201 let Some(arguments) = function.arguments.clone() else { continue; };
202
203 let internal_call_id = nanoid::nanoid!();
204 current_tool_call = Some((id.clone(), internal_call_id.clone(), name.clone(), arguments));
205
206 yield Ok(RawStreamingChoice::ToolCallDelta {
207 id,
208 internal_call_id,
209 content: ToolCallDeltaContent::Name(name),
210 });
211 },
212
213 StreamingEvent::ToolCallDelta { delta: Some(delta) } => {
214 let Some(message) = &delta.message else { continue; };
215 let Some(tool_calls) = &message.tool_calls else { continue; };
216 let Some(function) = &tool_calls.function else { continue; };
217 let Some(arguments) = function.arguments.clone() else { continue; };
218
219 let Some(tc) = current_tool_call.clone() else { continue; };
220 current_tool_call = Some((tc.0.clone(), tc.1.clone(), tc.2, format!("{}{}", tc.3, arguments)));
221
222 yield Ok(RawStreamingChoice::ToolCallDelta {
224 id: tc.0,
225 internal_call_id: tc.1,
226 content: ToolCallDeltaContent::Delta(arguments),
227 });
228 },
229
230 StreamingEvent::ToolCallEnd => {
231 let Some(tc) = current_tool_call.clone() else { continue; };
232 let Ok(args) = serde_json::from_str::<serde_json::Value>(&tc.3) else { continue; };
233
234 tool_calls.push(ToolCall {
235 id: Some(tc.0.clone()),
236 r#type: Some(ToolType::Function),
237 function: Some(ToolCallFunction {
238 name: tc.2.clone(),
239 arguments: args.clone()
240 })
241 });
242
243 let raw_tool_call = RawStreamingToolCall::new(tc.0, tc.2, args)
244 .with_internal_call_id(tc.1);
245 yield Ok(RawStreamingChoice::ToolCall(raw_tool_call));
246
247 current_tool_call = None;
248 },
249
250 _ => {}
251 }
252 },
253 Err(crate::http_client::Error::StreamEnded) => {
254 break;
255 }
256 Err(err) => {
257 tracing::error!(?err, "SSE error");
258 yield Err(CompletionError::ProviderError(err.to_string()));
259 break;
260 }
261 }
262 }
263
264 event_source.close();
266
267 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
268 usage: final_usage.unwrap_or_default()
269 }))
270 }.instrument(span);
271
272 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
273 stream,
274 )))
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes `payload` into a `StreamingEvent`, panicking if it does not parse.
    fn parse_event(payload: serde_json::Value) -> StreamingEvent {
        serde_json::from_value(payload).unwrap()
    }

    /// Builds a `content-delta` payload carrying `text`.
    fn content_delta(text: &str) -> serde_json::Value {
        json!({
            "type": "content-delta",
            "delta": { "message": { "content": { "text": text } } }
        })
    }

    /// Builds a `tool-call-delta` payload carrying one argument chunk.
    fn tool_call_delta(arguments: &str) -> serde_json::Value {
        json!({
            "type": "tool-call-delta",
            "delta": {
                "message": { "tool_calls": { "function": { "arguments": arguments } } }
            }
        })
    }

    /// Builds a `message-end` payload reporting the given token counts.
    fn message_end(input_tokens: u64, output_tokens: u64) -> serde_json::Value {
        json!({
            "type": "message-end",
            "delta": {
                "usage": {
                    "tokens": {
                        "input_tokens": input_tokens,
                        "output_tokens": output_tokens
                    }
                }
            }
        })
    }

    #[test]
    fn test_message_content_delta_deserialization() {
        let StreamingEvent::ContentDelta { delta } = parse_event(content_delta("Hello world"))
        else {
            panic!("Expected ContentDelta");
        };

        assert!(delta.is_some());
        let content = delta.unwrap().message.unwrap().content.unwrap();
        assert_eq!(content.text, Some("Hello world".to_string()));
    }

    #[test]
    fn test_tool_call_start_deserialization() {
        let payload = json!({
            "type": "tool-call-start",
            "delta": {
                "message": {
                    "tool_calls": {
                        "id": "call_123",
                        "function": { "name": "get_weather", "arguments": "{" }
                    }
                }
            }
        });

        let StreamingEvent::ToolCallStart { delta } = parse_event(payload) else {
            panic!("Expected ToolCallStart");
        };

        assert!(delta.is_some());
        let call = delta.unwrap().message.unwrap().tool_calls.unwrap();
        assert_eq!(call.id, Some("call_123".to_string()));
        assert_eq!(
            call.function.unwrap().name,
            Some("get_weather".to_string())
        );
    }

    #[test]
    fn test_tool_call_delta_deserialization() {
        let StreamingEvent::ToolCallDelta { delta } = parse_event(tool_call_delta("\"location\""))
        else {
            panic!("Expected ToolCallDelta");
        };

        assert!(delta.is_some());
        let function = delta
            .unwrap()
            .message
            .unwrap()
            .tool_calls
            .unwrap()
            .function
            .unwrap();
        assert_eq!(function.arguments, Some("\"location\"".to_string()));
    }

    #[test]
    fn test_tool_call_end_deserialization() {
        let event = parse_event(json!({ "type": "tool-call-end" }));

        if !matches!(event, StreamingEvent::ToolCallEnd) {
            panic!("Expected ToolCallEnd");
        }
    }

    #[test]
    fn test_message_end_with_usage_deserialization() {
        let StreamingEvent::MessageEnd { delta } = parse_event(message_end(100, 50)) else {
            panic!("Expected MessageEnd");
        };

        assert!(delta.is_some());
        let tokens = delta.unwrap().usage.unwrap().tokens.unwrap();
        assert_eq!(tokens.input_tokens, Some(100.0));
        assert_eq!(tokens.output_tokens, Some(50.0));
    }

    #[test]
    fn test_streaming_event_order() {
        // A full response: text content first, then one tool call, then usage.
        let events = vec![
            json!({"type": "message-start"}),
            json!({"type": "content-start"}),
            content_delta("Sure, "),
            content_delta("I can help with that."),
            json!({"type": "content-end"}),
            json!({"type": "tool-plan"}),
            json!({
                "type": "tool-call-start",
                "delta": {
                    "message": {
                        "tool_calls": {
                            "id": "call_abc",
                            "function": { "name": "search", "arguments": "" }
                        }
                    }
                }
            }),
            tool_call_delta("{\"query\":"),
            tool_call_delta("\"Rust\"}"),
            json!({"type": "tool-call-end"}),
            message_end(50, 25),
        ];

        for (index, payload) in events.into_iter().enumerate() {
            let outcome = serde_json::from_value::<StreamingEvent>(payload);
            assert!(
                outcome.is_ok(),
                "Failed to deserialize event at index {}: {:?}",
                index,
                outcome.err()
            );
        }
    }
}