rig/providers/cohere/
streaming.rs1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::providers::cohere::CompletionModel;
3use crate::providers::cohere::completion::Usage;
4use crate::streaming::RawStreamingChoice;
5use crate::{json_utils, streaming};
6use async_stream::stream;
7use futures::StreamExt;
8use reqwest_eventsource::{Event, RequestBuilderExt};
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Deserialize)]
12#[serde(rename_all = "kebab-case", tag = "type")]
13enum StreamingEvent {
14 MessageStart,
15 ContentStart,
16 ContentDelta { delta: Option<Delta> },
17 ContentEnd,
18 ToolPlan,
19 ToolCallStart { delta: Option<Delta> },
20 ToolCallDelta { delta: Option<Delta> },
21 ToolCallEnd,
22 MessageEnd { delta: Option<MessageEndDelta> },
23}
24
25#[derive(Debug, Deserialize)]
26struct MessageContentDelta {
27 text: Option<String>,
28}
29
30#[derive(Debug, Deserialize)]
31struct MessageToolFunctionDelta {
32 name: Option<String>,
33 arguments: Option<String>,
34}
35
36#[derive(Debug, Deserialize)]
37struct MessageToolCallDelta {
38 id: Option<String>,
39 function: Option<MessageToolFunctionDelta>,
40}
41
42#[derive(Debug, Deserialize)]
43struct MessageDelta {
44 content: Option<MessageContentDelta>,
45 tool_calls: Option<MessageToolCallDelta>,
46}
47
48#[derive(Debug, Deserialize)]
49struct Delta {
50 message: Option<MessageDelta>,
51}
52
53#[derive(Debug, Deserialize)]
54struct MessageEndDelta {
55 usage: Option<Usage>,
56}
57
58#[derive(Clone, Serialize, Deserialize)]
59pub struct StreamingCompletionResponse {
60 pub usage: Option<Usage>,
61}
62
63impl GetTokenUsage for StreamingCompletionResponse {
64 fn token_usage(&self) -> Option<crate::completion::Usage> {
65 let tokens = self
66 .usage
67 .clone()
68 .and_then(|response| response.tokens)
69 .map(|tokens| {
70 (
71 tokens.input_tokens.map(|x| x as u64),
72 tokens.output_tokens.map(|y| y as u64),
73 )
74 });
75 let Some((Some(input), Some(output))) = tokens else {
76 return None;
77 };
78 let mut usage = crate::completion::Usage::new();
79 usage.input_tokens = input;
80 usage.output_tokens = output;
81 usage.total_tokens = input + output;
82
83 Some(usage)
84 }
85}
86
87impl CompletionModel {
88 pub(crate) async fn stream(
89 &self,
90 request: CompletionRequest,
91 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
92 {
93 let request = self.create_completion_request(request)?;
94 let request = json_utils::merge(request, serde_json::json!({"stream": true}));
95
96 tracing::debug!(
97 "Cohere request: {}",
98 serde_json::to_string_pretty(&request)?
99 );
100
101 let mut event_source = self
102 .client
103 .post("/v2/chat")
104 .json(&request)
105 .eventsource()
106 .map_err(|e| CompletionError::ProviderError(e.to_string()))?;
107
108 let stream = Box::pin(stream! {
109 let mut current_tool_call: Option<(String, String, String)> = None;
110
111 while let Some(event_result) = event_source.next().await {
112 match event_result {
113 Ok(Event::Open) => {
114 tracing::trace!("SSE connection opened");
115 continue;
116 }
117
118 Ok(Event::Message(message)) => {
119 let data_str = message.data.trim();
120 if data_str.is_empty() || data_str == "[DONE]" {
121 continue;
122 }
123
124 let event: StreamingEvent = match serde_json::from_str(data_str) {
125 Ok(ev) => ev,
126 Err(_) => {
127 tracing::debug!("Couldn't parse SSE payload as StreamingEvent");
128 continue;
129 }
130 };
131
132 match event {
133 StreamingEvent::ContentDelta { delta: Some(delta) } => {
134 let Some(message) = &delta.message else { continue; };
135 let Some(content) = &message.content else { continue; };
136 let Some(text) = &content.text else { continue; };
137
138 yield Ok(RawStreamingChoice::Message(text.clone()));
139 },
140
141 StreamingEvent::MessageEnd { delta: Some(delta) } => {
142 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
143 usage: delta.usage.clone()
144 }));
145 },
146
147 StreamingEvent::ToolCallStart { delta: Some(delta) } => {
148 let Some(message) = &delta.message else { continue; };
149 let Some(tool_calls) = &message.tool_calls else { continue; };
150 let Some(id) = tool_calls.id.clone() else { continue; };
151 let Some(function) = &tool_calls.function else { continue; };
152 let Some(name) = function.name.clone() else { continue; };
153 let Some(arguments) = function.arguments.clone() else { continue; };
154
155 current_tool_call = Some((id, name, arguments));
156 },
157
158 StreamingEvent::ToolCallDelta { delta: Some(delta) } => {
159 let Some(message) = &delta.message else { continue; };
160 let Some(tool_calls) = &message.tool_calls else { continue; };
161 let Some(function) = &tool_calls.function else { continue; };
162 let Some(arguments) = function.arguments.clone() else { continue; };
163
164 let Some(tc) = current_tool_call.clone() else { continue; };
165 current_tool_call = Some((tc.0, tc.1, format!("{}{}", tc.2, arguments)));
166 },
167
168 StreamingEvent::ToolCallEnd => {
169 let Some(tc) = current_tool_call.clone() else { continue; };
170 let Ok(args) = serde_json::from_str(&tc.2) else { continue; };
171
172 yield Ok(RawStreamingChoice::ToolCall {
173 id: tc.0,
174 name: tc.1,
175 arguments: args,
176 call_id: None
177 });
178
179 current_tool_call = None;
180 },
181
182 _ => {}
183 }
184 },
185
186 Err(reqwest_eventsource::Error::StreamEnded) => break,
187
188 Err(err) => {
189 tracing::error!(?err, "SSE error");
190 yield Err(CompletionError::ResponseError(err.to_string()));
191 break;
192 }
193 }
194 }
195
196 event_source.close();
197 });
198
199 Ok(streaming::StreamingCompletionResponse::stream(stream))
200 }
201}