rig/providers/cohere/streaming.rs

1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::providers::cohere::CompletionModel;
3use crate::providers::cohere::completion::Usage;
4use crate::streaming::RawStreamingChoice;
5use crate::{json_utils, streaming};
6use async_stream::stream;
7use futures::StreamExt;
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10
11#[derive(Debug, Deserialize)]
12#[serde(rename_all = "kebab-case", tag = "type")]
13enum StreamingEvent {
14    MessageStart,
15    ContentStart,
16    ContentDelta { delta: Option<Delta> },
17    ContentEnd,
18    ToolPlan,
19    ToolCallStart { delta: Option<Delta> },
20    ToolCallDelta { delta: Option<Delta> },
21    ToolCallEnd,
22    MessageEnd { delta: Option<MessageEndDelta> },
23}
24
25#[derive(Debug, Deserialize)]
26struct MessageContentDelta {
27    text: Option<String>,
28}
29
30#[derive(Debug, Deserialize)]
31struct MessageToolFunctionDelta {
32    name: Option<String>,
33    arguments: Option<String>,
34}
35
36#[derive(Debug, Deserialize)]
37struct MessageToolCallDelta {
38    id: Option<String>,
39    function: Option<MessageToolFunctionDelta>,
40}
41
42#[derive(Debug, Deserialize)]
43struct MessageDelta {
44    content: Option<MessageContentDelta>,
45    tool_calls: Option<MessageToolCallDelta>,
46}
47
48#[derive(Debug, Deserialize)]
49struct Delta {
50    message: Option<MessageDelta>,
51}
52
53#[derive(Debug, Deserialize)]
54struct MessageEndDelta {
55    usage: Option<Usage>,
56}
57
58#[derive(Clone, Serialize, Deserialize)]
59pub struct StreamingCompletionResponse {
60    pub usage: Option<Usage>,
61}
62
63impl GetTokenUsage for StreamingCompletionResponse {
64    fn token_usage(&self) -> Option<crate::completion::Usage> {
65        let tokens = self
66            .usage
67            .clone()
68            .and_then(|response| response.tokens)
69            .map(|tokens| {
70                (
71                    tokens.input_tokens.map(|x| x as u64),
72                    tokens.output_tokens.map(|y| y as u64),
73                )
74            });
75        let Some((Some(input), Some(output))) = tokens else {
76            return None;
77        };
78        let mut usage = crate::completion::Usage::new();
79        usage.input_tokens = input;
80        usage.output_tokens = output;
81        usage.total_tokens = input + output;
82
83        Some(usage)
84    }
85}
86
87impl CompletionModel {
88    pub(crate) async fn stream(
89        &self,
90        request: CompletionRequest,
91    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
92    {
93        let request = self.create_completion_request(request)?;
94        let request = json_utils::merge(request, json!({"stream": true}));
95
96        tracing::debug!(
97            "Cohere request: {}",
98            serde_json::to_string_pretty(&request)?
99        );
100
101        let response = self.client.post("/v2/chat").json(&request).send().await?;
102
103        if !response.status().is_success() {
104            return Err(CompletionError::ProviderError(format!(
105                "{}: {}",
106                response.status(),
107                response.text().await?
108            )));
109        }
110
111        let stream = Box::pin(stream! {
112            let mut stream = response.bytes_stream();
113            let mut current_tool_call: Option<(String, String, String)> = None;
114
115            while let Some(chunk_result) = stream.next().await {
116               let chunk = match chunk_result {
117                    Ok(c) => c,
118                    Err(e) => {
119                        yield Err(CompletionError::from(e));
120                        break;
121                    }
122                };
123
124               let text = match String::from_utf8(chunk.to_vec()) {
125                    Ok(t) => t,
126                    Err(e) => {
127                        yield Err(CompletionError::ResponseError(e.to_string()));
128                        break;
129                    }
130               };
131
132                for line in text.lines() {
133
134                    let Some(line) = line.strip_prefix("data: ") else {
135                        continue;
136                    };
137
138                    let event = {
139                       let result = serde_json::from_str::<StreamingEvent>(line);
140
141                       let Ok(event) = result else {
142                           continue;
143                       };
144
145                        event
146                    };
147
148                    match event {
149                        StreamingEvent::ContentDelta { delta: Some(delta) } => {
150                            let Some(message) = &delta.message else { continue; };
151                            let Some(content) = &message.content else { continue; };
152                            let Some(text) = &content.text else { continue; };
153
154                            yield Ok(RawStreamingChoice::Message(text.clone()));
155                        },
156                        StreamingEvent::MessageEnd {delta: Some(delta)} => {
157                            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
158                                usage: delta.usage.clone()
159                            }));
160                        },
161                        StreamingEvent::ToolCallStart { delta: Some(delta)} => {
162                            // Skip the delta if there's any missing information,
163                            // though this *should* all be present
164                            let Some(message) = &delta.message else { continue; };
165                            let Some(tool_calls) = &message.tool_calls else { continue; };
166                            let Some(id) = tool_calls.id.clone() else { continue; };
167                            let Some(function) = &tool_calls.function else { continue; };
168                            let Some(name) = function.name.clone() else { continue; };
169                            let Some(arguments) = function.arguments.clone() else { continue; };
170
171                            current_tool_call = Some((id, name, arguments));
172                        },
173                        StreamingEvent::ToolCallDelta { delta: Some(delta)} => {
174                            // Skip the delta if there's any missing information,
175                            // though this *should* all be present
176                            let Some(message) = &delta.message else { continue; };
177                            let Some(tool_calls) = &message.tool_calls else { continue; };
178                            let Some(function) = &tool_calls.function else { continue; };
179                            let Some(arguments) = function.arguments.clone() else { continue; };
180
181                            if let Some(tc) = current_tool_call.clone() {
182                                current_tool_call = Some((
183                                    tc.0,
184                                    tc.1,
185                                    format!("{}{}", tc.2, arguments)
186                                ));
187                            };
188                        },
189                        StreamingEvent::ToolCallEnd => {
190                            let Some(tc) = current_tool_call.clone() else { continue; };
191
192                            let Ok(args) = serde_json::from_str(&tc.2) else { continue; };
193
194                            yield Ok(RawStreamingChoice::ToolCall {
195                                id: tc.0,
196                                name: tc.1,
197                                arguments: args,
198                                call_id: None
199                            });
200
201                            current_tool_call = None;
202                        },
203                        _ => {}
204                    };
205                }
206            }
207        });
208
209        Ok(streaming::StreamingCompletionResponse::stream(stream))
210    }
211}