rig/providers/cohere/
streaming.rs1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::providers::cohere::CompletionModel;
3use crate::providers::cohere::completion::Usage;
4use crate::streaming::RawStreamingChoice;
5use crate::{json_utils, streaming};
6use async_stream::stream;
7use futures::StreamExt;
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10
11#[derive(Debug, Deserialize)]
12#[serde(rename_all = "kebab-case", tag = "type")]
13enum StreamingEvent {
14 MessageStart,
15 ContentStart,
16 ContentDelta { delta: Option<Delta> },
17 ContentEnd,
18 ToolPlan,
19 ToolCallStart { delta: Option<Delta> },
20 ToolCallDelta { delta: Option<Delta> },
21 ToolCallEnd,
22 MessageEnd { delta: Option<MessageEndDelta> },
23}
24
25#[derive(Debug, Deserialize)]
26struct MessageContentDelta {
27 text: Option<String>,
28}
29
30#[derive(Debug, Deserialize)]
31struct MessageToolFunctionDelta {
32 name: Option<String>,
33 arguments: Option<String>,
34}
35
36#[derive(Debug, Deserialize)]
37struct MessageToolCallDelta {
38 id: Option<String>,
39 function: Option<MessageToolFunctionDelta>,
40}
41
42#[derive(Debug, Deserialize)]
43struct MessageDelta {
44 content: Option<MessageContentDelta>,
45 tool_calls: Option<MessageToolCallDelta>,
46}
47
48#[derive(Debug, Deserialize)]
49struct Delta {
50 message: Option<MessageDelta>,
51}
52
53#[derive(Debug, Deserialize)]
54struct MessageEndDelta {
55 usage: Option<Usage>,
56}
57
58#[derive(Clone, Serialize, Deserialize)]
59pub struct StreamingCompletionResponse {
60 pub usage: Option<Usage>,
61}
62
63impl GetTokenUsage for StreamingCompletionResponse {
64 fn token_usage(&self) -> Option<crate::completion::Usage> {
65 let tokens = self
66 .usage
67 .clone()
68 .and_then(|response| response.tokens)
69 .map(|tokens| {
70 (
71 tokens.input_tokens.map(|x| x as u64),
72 tokens.output_tokens.map(|y| y as u64),
73 )
74 });
75 let Some((Some(input), Some(output))) = tokens else {
76 return None;
77 };
78 let mut usage = crate::completion::Usage::new();
79 usage.input_tokens = input;
80 usage.output_tokens = output;
81 usage.total_tokens = input + output;
82
83 Some(usage)
84 }
85}
86
87impl CompletionModel {
88 pub(crate) async fn stream(
89 &self,
90 request: CompletionRequest,
91 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
92 {
93 let request = self.create_completion_request(request)?;
94 let request = json_utils::merge(request, json!({"stream": true}));
95
96 tracing::debug!(
97 "Cohere request: {}",
98 serde_json::to_string_pretty(&request)?
99 );
100
101 let response = self.client.post("/v2/chat").json(&request).send().await?;
102
103 if !response.status().is_success() {
104 return Err(CompletionError::ProviderError(format!(
105 "{}: {}",
106 response.status(),
107 response.text().await?
108 )));
109 }
110
111 let stream = Box::pin(stream! {
112 let mut stream = response.bytes_stream();
113 let mut current_tool_call: Option<(String, String, String)> = None;
114
115 while let Some(chunk_result) = stream.next().await {
116 let chunk = match chunk_result {
117 Ok(c) => c,
118 Err(e) => {
119 yield Err(CompletionError::from(e));
120 break;
121 }
122 };
123
124 let text = match String::from_utf8(chunk.to_vec()) {
125 Ok(t) => t,
126 Err(e) => {
127 yield Err(CompletionError::ResponseError(e.to_string()));
128 break;
129 }
130 };
131
132 for line in text.lines() {
133
134 let Some(line) = line.strip_prefix("data: ") else {
135 continue;
136 };
137
138 let event = {
139 let result = serde_json::from_str::<StreamingEvent>(line);
140
141 let Ok(event) = result else {
142 continue;
143 };
144
145 event
146 };
147
148 match event {
149 StreamingEvent::ContentDelta { delta: Some(delta) } => {
150 let Some(message) = &delta.message else { continue; };
151 let Some(content) = &message.content else { continue; };
152 let Some(text) = &content.text else { continue; };
153
154 yield Ok(RawStreamingChoice::Message(text.clone()));
155 },
156 StreamingEvent::MessageEnd {delta: Some(delta)} => {
157 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
158 usage: delta.usage.clone()
159 }));
160 },
161 StreamingEvent::ToolCallStart { delta: Some(delta)} => {
162 let Some(message) = &delta.message else { continue; };
165 let Some(tool_calls) = &message.tool_calls else { continue; };
166 let Some(id) = tool_calls.id.clone() else { continue; };
167 let Some(function) = &tool_calls.function else { continue; };
168 let Some(name) = function.name.clone() else { continue; };
169 let Some(arguments) = function.arguments.clone() else { continue; };
170
171 current_tool_call = Some((id, name, arguments));
172 },
173 StreamingEvent::ToolCallDelta { delta: Some(delta)} => {
174 let Some(message) = &delta.message else { continue; };
177 let Some(tool_calls) = &message.tool_calls else { continue; };
178 let Some(function) = &tool_calls.function else { continue; };
179 let Some(arguments) = function.arguments.clone() else { continue; };
180
181 if let Some(tc) = current_tool_call.clone() {
182 current_tool_call = Some((
183 tc.0,
184 tc.1,
185 format!("{}{}", tc.2, arguments)
186 ));
187 };
188 },
189 StreamingEvent::ToolCallEnd => {
190 let Some(tc) = current_tool_call.clone() else { continue; };
191
192 let Ok(args) = serde_json::from_str(&tc.2) else { continue; };
193
194 yield Ok(RawStreamingChoice::ToolCall {
195 id: tc.0,
196 name: tc.1,
197 arguments: args,
198 call_id: None
199 });
200
201 current_tool_call = None;
202 },
203 _ => {}
204 };
205 }
206 }
207 });
208
209 Ok(streaming::StreamingCompletionResponse::stream(stream))
210 }
211}