1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::http_client::HttpClientExt;
3use crate::http_client::sse::{Event, GenericEventSource};
4use crate::providers::cohere::CompletionModel;
5use crate::providers::cohere::completion::{
6 AssistantContent, CohereCompletionRequest, Message, ToolCall, ToolCallFunction, ToolType, Usage,
7};
8use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent};
9use crate::telemetry::SpanCombinator;
10use crate::{json_utils, streaming};
11use async_stream::stream;
12use futures::StreamExt;
13use serde::{Deserialize, Serialize};
14use tracing::{Level, enabled, info_span};
15use tracing_futures::Instrument;
16
17#[derive(Debug, Deserialize)]
18#[serde(rename_all = "kebab-case", tag = "type")]
19enum StreamingEvent {
20 MessageStart,
21 ContentStart,
22 ContentDelta { delta: Option<Delta> },
23 ContentEnd,
24 ToolPlan,
25 ToolCallStart { delta: Option<Delta> },
26 ToolCallDelta { delta: Option<Delta> },
27 ToolCallEnd,
28 MessageEnd { delta: Option<MessageEndDelta> },
29}
30
31#[derive(Debug, Deserialize)]
32struct MessageContentDelta {
33 text: Option<String>,
34}
35
36#[derive(Debug, Deserialize)]
37struct MessageToolFunctionDelta {
38 name: Option<String>,
39 arguments: Option<String>,
40}
41
42#[derive(Debug, Deserialize)]
43struct MessageToolCallDelta {
44 id: Option<String>,
45 function: Option<MessageToolFunctionDelta>,
46}
47
48#[derive(Debug, Deserialize)]
49struct MessageDelta {
50 content: Option<MessageContentDelta>,
51 tool_calls: Option<MessageToolCallDelta>,
52}
53
54#[derive(Debug, Deserialize)]
55struct Delta {
56 message: Option<MessageDelta>,
57}
58
59#[derive(Debug, Deserialize)]
60struct MessageEndDelta {
61 usage: Option<Usage>,
62}
63
64#[derive(Clone, Serialize, Deserialize)]
65pub struct StreamingCompletionResponse {
66 pub usage: Option<Usage>,
67}
68
69impl GetTokenUsage for StreamingCompletionResponse {
70 fn token_usage(&self) -> Option<crate::completion::Usage> {
71 let tokens = self
72 .usage
73 .clone()
74 .and_then(|response| response.tokens)
75 .map(|tokens| {
76 (
77 tokens.input_tokens.map(|x| x as u64),
78 tokens.output_tokens.map(|y| y as u64),
79 )
80 });
81 let Some((Some(input), Some(output))) = tokens else {
82 return None;
83 };
84 let mut usage = crate::completion::Usage::new();
85 usage.input_tokens = input;
86 usage.output_tokens = output;
87 usage.total_tokens = input + output;
88
89 Some(usage)
90 }
91}
92
93impl<T> CompletionModel<T>
94where
95 T: HttpClientExt + Clone + 'static,
96{
97 pub(crate) async fn stream(
98 &self,
99 request: CompletionRequest,
100 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
101 {
102 let mut request = CohereCompletionRequest::try_from((self.model.as_ref(), request))?;
103 let span = if tracing::Span::current().is_disabled() {
104 info_span!(
105 target: "rig::completions",
106 "chat_streaming",
107 gen_ai.operation.name = "chat_streaming",
108 gen_ai.provider.name = "cohere",
109 gen_ai.request.model = self.model,
110 gen_ai.response.id = tracing::field::Empty,
111 gen_ai.response.model = self.model,
112 gen_ai.usage.output_tokens = tracing::field::Empty,
113 gen_ai.usage.input_tokens = tracing::field::Empty,
114 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
115 )
116 } else {
117 tracing::Span::current()
118 };
119
120 let params = json_utils::merge(
121 request.additional_params.unwrap_or(serde_json::json!({})),
122 serde_json::json!({"stream": true}),
123 );
124
125 request.additional_params = Some(params);
126
127 if enabled!(Level::TRACE) {
128 tracing::trace!(
129 target: "rig::streaming",
130 "Cohere streaming completion input: {}",
131 serde_json::to_string_pretty(&request)?
132 );
133 }
134
135 let body = serde_json::to_vec(&request)?;
136
137 let req = self
138 .client
139 .post("/v2/chat")?
140 .body(body)
141 .map_err(|e| CompletionError::HttpError(e.into()))?;
142
143 let mut event_source = GenericEventSource::new(self.client.clone(), req);
144
145 let stream = stream! {
146 let mut current_tool_call: Option<(String, String, String, String)> = None;
147 let mut text_response = String::new();
148 let mut tool_calls = Vec::new();
149 let mut final_usage = None;
150
151 while let Some(event_result) = event_source.next().await {
152 match event_result {
153 Ok(Event::Open) => {
154 tracing::trace!("SSE connection opened");
155 continue;
156 }
157
158 Ok(Event::Message(message)) => {
159 let data_str = message.data.trim();
160 if data_str.is_empty() || data_str == "[DONE]" {
161 continue;
162 }
163
164 let event: StreamingEvent = match serde_json::from_str(data_str) {
165 Ok(ev) => ev,
166 Err(_) => {
167 tracing::debug!("Couldn't parse SSE payload as StreamingEvent");
168 continue;
169 }
170 };
171
172 match event {
173 StreamingEvent::ContentDelta { delta: Some(delta) } => {
174 let Some(message) = &delta.message else { continue; };
175 let Some(content) = &message.content else { continue; };
176 let Some(text) = &content.text else { continue; };
177
178 text_response += text;
179
180 yield Ok(RawStreamingChoice::Message(text.clone()));
181 },
182
183 StreamingEvent::MessageEnd { delta: Some(delta) } => {
184 let message = Message::Assistant {
185 tool_calls: tool_calls.clone(),
186 content: vec![AssistantContent::Text { text: text_response.clone() }],
187 tool_plan: None,
188 citations: vec![]
189 };
190
191 let span = tracing::Span::current();
192 span.record_token_usage(&delta.usage);
193 span.record_model_output(&vec![message]);
194
195 final_usage = Some(delta.usage.clone());
196 break;
197 },
198
199 StreamingEvent::ToolCallStart { delta: Some(delta) } => {
200 let Some(message) = &delta.message else { continue; };
201 let Some(tool_calls) = &message.tool_calls else { continue; };
202 let Some(id) = tool_calls.id.clone() else { continue; };
203 let Some(function) = &tool_calls.function else { continue; };
204 let Some(name) = function.name.clone() else { continue; };
205 let Some(arguments) = function.arguments.clone() else { continue; };
206
207 let internal_call_id = nanoid::nanoid!();
208 current_tool_call = Some((id.clone(), internal_call_id.clone(), name.clone(), arguments));
209
210 yield Ok(RawStreamingChoice::ToolCallDelta {
211 id,
212 internal_call_id,
213 content: ToolCallDeltaContent::Name(name),
214 });
215 },
216
217 StreamingEvent::ToolCallDelta { delta: Some(delta) } => {
218 let Some(message) = &delta.message else { continue; };
219 let Some(tool_calls) = &message.tool_calls else { continue; };
220 let Some(function) = &tool_calls.function else { continue; };
221 let Some(arguments) = function.arguments.clone() else { continue; };
222
223 let Some(tc) = current_tool_call.clone() else { continue; };
224 current_tool_call = Some((tc.0.clone(), tc.1.clone(), tc.2, format!("{}{}", tc.3, arguments)));
225
226 yield Ok(RawStreamingChoice::ToolCallDelta {
228 id: tc.0,
229 internal_call_id: tc.1,
230 content: ToolCallDeltaContent::Delta(arguments),
231 });
232 },
233
234 StreamingEvent::ToolCallEnd => {
235 let Some(tc) = current_tool_call.clone() else { continue; };
236 let Ok(args) = json_utils::parse_tool_arguments(&tc.3) else { continue; };
237
238 tool_calls.push(ToolCall {
239 id: Some(tc.0.clone()),
240 r#type: Some(ToolType::Function),
241 function: Some(ToolCallFunction {
242 name: tc.2.clone(),
243 arguments: args.clone()
244 })
245 });
246
247 let raw_tool_call = RawStreamingToolCall::new(tc.0, tc.2, args)
248 .with_internal_call_id(tc.1);
249 yield Ok(RawStreamingChoice::ToolCall(raw_tool_call));
250
251 current_tool_call = None;
252 },
253
254 _ => {}
255 }
256 },
257 Err(crate::http_client::Error::StreamEnded) => {
258 break;
259 }
260 Err(err) => {
261 tracing::error!(?err, "SSE error");
262 yield Err(CompletionError::ProviderError(err.to_string()));
263 break;
264 }
265 }
266 }
267
268 event_source.close();
270
271 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
272 usage: final_usage.unwrap_or_default()
273 }))
274 }.instrument(span);
275
276 Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
277 stream,
278 )))
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_message_content_delta_deserialization() {
        let json = json!({
            "type": "content-delta",
            "delta": {
                "message": {
                    "content": {
                        "text": "Hello world"
                    }
                }
            }
        });

        let event: StreamingEvent = serde_json::from_value(json).unwrap();
        match event {
            StreamingEvent::ContentDelta { delta } => {
                assert!(delta.is_some());
                let message = delta.unwrap().message.unwrap();
                let content = message.content.unwrap();
                assert_eq!(content.text, Some("Hello world".to_string()));
            }
            _ => panic!("Expected ContentDelta"),
        }
    }

    #[test]
    fn test_tool_call_start_deserialization() {
        let json = json!({
            "type": "tool-call-start",
            "delta": {
                "message": {
                    "tool_calls": {
                        "id": "call_123",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{"
                        }
                    }
                }
            }
        });

        let event: StreamingEvent = serde_json::from_value(json).unwrap();
        match event {
            StreamingEvent::ToolCallStart { delta } => {
                assert!(delta.is_some());
                let tool_call = delta.unwrap().message.unwrap().tool_calls.unwrap();
                assert_eq!(tool_call.id, Some("call_123".to_string()));
                assert_eq!(
                    tool_call.function.unwrap().name,
                    Some("get_weather".to_string())
                );
            }
            _ => panic!("Expected ToolCallStart"),
        }
    }

    #[test]
    fn test_tool_call_delta_deserialization() {
        let json = json!({
            "type": "tool-call-delta",
            "delta": {
                "message": {
                    "tool_calls": {
                        "function": {
                            "arguments": "\"location\""
                        }
                    }
                }
            }
        });

        let event: StreamingEvent = serde_json::from_value(json).unwrap();
        match event {
            StreamingEvent::ToolCallDelta { delta } => {
                assert!(delta.is_some());
                let tool_call = delta.unwrap().message.unwrap().tool_calls.unwrap();
                let function = tool_call.function.unwrap();
                assert_eq!(function.arguments, Some("\"location\"".to_string()));
            }
            _ => panic!("Expected ToolCallDelta"),
        }
    }

    #[test]
    fn test_tool_call_end_deserialization() {
        let json = json!({
            "type": "tool-call-end"
        });

        let event: StreamingEvent = serde_json::from_value(json).unwrap();
        assert!(
            matches!(event, StreamingEvent::ToolCallEnd),
            "Expected ToolCallEnd, got {event:?}"
        );
    }

    #[test]
    fn test_message_end_with_usage_deserialization() {
        let json = json!({
            "type": "message-end",
            "delta": {
                "usage": {
                    "tokens": {
                        "input_tokens": 100,
                        "output_tokens": 50
                    }
                }
            }
        });

        let event: StreamingEvent = serde_json::from_value(json).unwrap();
        match event {
            StreamingEvent::MessageEnd { delta } => {
                assert!(delta.is_some());
                let usage = delta.unwrap().usage.unwrap();
                let tokens = usage.tokens.unwrap();
                assert_eq!(tokens.input_tokens, Some(100.0));
                assert_eq!(tokens.output_tokens, Some(50.0));
            }
            _ => panic!("Expected MessageEnd"),
        }
    }

    #[test]
    fn test_message_end_without_delta_deserialization() {
        let event: StreamingEvent =
            serde_json::from_value(json!({"type": "message-end"})).unwrap();
        match event {
            StreamingEvent::MessageEnd { delta } => assert!(delta.is_none()),
            _ => panic!("Expected MessageEnd"),
        }
    }

    #[test]
    fn test_token_usage_from_streaming_response() {
        let response: StreamingCompletionResponse = serde_json::from_value(json!({
            "usage": {
                "tokens": {
                    "input_tokens": 100,
                    "output_tokens": 50
                }
            }
        }))
        .unwrap();

        let usage = response
            .token_usage()
            .expect("both token counts are present");
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn test_token_usage_requires_both_counts() {
        let missing_output: StreamingCompletionResponse = serde_json::from_value(json!({
            "usage": {
                "tokens": {
                    "input_tokens": 100
                }
            }
        }))
        .unwrap();
        assert!(missing_output.token_usage().is_none());

        let no_usage = StreamingCompletionResponse { usage: None };
        assert!(no_usage.token_usage().is_none());
    }

    #[test]
    fn test_streaming_event_order() {
        let events = vec![
            json!({"type": "message-start"}),
            json!({"type": "content-start"}),
            json!({
                "type": "content-delta",
                "delta": {
                    "message": {
                        "content": {
                            "text": "Sure, "
                        }
                    }
                }
            }),
            json!({
                "type": "content-delta",
                "delta": {
                    "message": {
                        "content": {
                            "text": "I can help with that."
                        }
                    }
                }
            }),
            json!({"type": "content-end"}),
            json!({"type": "tool-plan"}),
            json!({
                "type": "tool-call-start",
                "delta": {
                    "message": {
                        "tool_calls": {
                            "id": "call_abc",
                            "function": {
                                "name": "search",
                                "arguments": ""
                            }
                        }
                    }
                }
            }),
            json!({
                "type": "tool-call-delta",
                "delta": {
                    "message": {
                        "tool_calls": {
                            "function": {
                                "arguments": "{\"query\":"
                            }
                        }
                    }
                }
            }),
            json!({
                "type": "tool-call-delta",
                "delta": {
                    "message": {
                        "tool_calls": {
                            "function": {
                                "arguments": "\"Rust\"}"
                            }
                        }
                    }
                }
            }),
            json!({"type": "tool-call-end"}),
            json!({
                "type": "message-end",
                "delta": {
                    "usage": {
                        "tokens": {
                            "input_tokens": 50,
                            "output_tokens": 25
                        }
                    }
                }
            }),
        ];

        for (i, event_json) in events.iter().enumerate() {
            let result = serde_json::from_value::<StreamingEvent>(event_json.clone());
            assert!(
                result.is_ok(),
                "Failed to deserialize event at index {}: {:?}",
                i,
                result.err()
            );
        }
    }
}