Skip to main content

agent_core_runtime/client/providers/cohere/
mod.rs

1//! Cohere API provider implementation.
2
3mod sse;
4mod types;
5
6use async_stream::stream;
7use futures::Stream;
8
9use crate::client::error::LlmError;
10use crate::client::http::HttpClient;
11use crate::client::models::{Message, MessageOptions, StreamEvent};
12use crate::client::traits::LlmProvider;
13use std::future::Future;
14use std::pin::Pin;
15
16// =============================================================================
17// Constants
18// =============================================================================
19
/// Error code attached to failures that occur while decoding the SSE byte stream.
const ERROR_SSE_DECODE: &str = "SSE_DECODE_ERROR";

/// Human-readable message reported when streamed bytes are not valid UTF-8.
const MSG_INVALID_UTF8: &str = "Invalid UTF-8 in stream";
25
26// =============================================================================
27// Provider
28// =============================================================================
29
/// LLM provider that talks to Cohere's chat API (including its tool-calling support).
///
/// Stores the API key used to build request headers and the model name used
/// whenever a request does not choose a model of its own.
pub struct CohereProvider {
    /// Secret Cohere API key, used to build the request headers.
    api_key: String,
    /// Default model name, for example "command-r-plus" or "command-r".
    model: String,
}

impl CohereProvider {
    /// Builds a provider from an API key and a default model name.
    pub fn new(api_key: String, model: String) -> Self {
        CohereProvider { model, api_key }
    }

    /// Returns the default model name configured at construction time.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }
}
51
52impl LlmProvider for CohereProvider {
53    fn send_msg(
54        &self,
55        client: &HttpClient,
56        messages: &[Message],
57        options: &MessageOptions,
58    ) -> Pin<Box<dyn Future<Output = Result<Message, LlmError>> + Send>> {
59        // Clone data for the async block
60        let client = client.clone();
61        let api_key = self.api_key.clone();
62        let model = options.model.as_deref().unwrap_or(&self.model).to_string();
63        let messages = messages.to_vec();
64        let options = options.clone();
65
66        Box::pin(async move {
67            // Build request body
68            let body = types::build_request_body(&messages, &options, &model)?;
69
70            // Get headers
71            let headers = types::get_request_headers(&api_key)?;
72            let headers_ref: Vec<(&str, &str)> = headers
73                .iter()
74                .map(|(k, v)| (*k, v.as_str()))
75                .collect();
76
77            // Get the API URL
78            let url = types::get_api_url();
79
80            // Make the API call
81            let response = client.post(&url, &headers_ref, &body).await?;
82
83            // Parse and return the response
84            types::parse_response(&response)
85        })
86    }
87
88    fn send_msg_stream(
89        &self,
90        client: &HttpClient,
91        messages: &[Message],
92        options: &MessageOptions,
93    ) -> Pin<Box<dyn Future<Output = Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>, LlmError>> + Send>> {
94        // Clone data for the async block
95        let client = client.clone();
96        let api_key = self.api_key.clone();
97        let model = options.model.as_deref().unwrap_or(&self.model).to_string();
98        let messages = messages.to_vec();
99        let options = options.clone();
100
101        Box::pin(async move {
102            // Build streaming request body
103            let body = types::build_streaming_request_body(&messages, &options, &model)?;
104
105            // Get headers
106            let headers = types::get_request_headers(&api_key)?;
107            let headers_ref: Vec<(&str, &str)> = headers
108                .iter()
109                .map(|(k, v)| (*k, v.as_str()))
110                .collect();
111
112            // Get the API URL
113            let url = types::get_api_url();
114
115            // Make the streaming API call
116            let byte_stream = client.post_stream(&url, &headers_ref, &body).await?;
117
118            // Convert byte stream to SSE events stream
119            use futures::StreamExt;
120            let event_stream = stream! {
121                let mut buffer = String::new();
122                let mut byte_stream = byte_stream;
123                let mut message_started = false;
124                let mut stream_state = sse::StreamState::default();
125
126                while let Some(chunk_result) = byte_stream.next().await {
127                    match chunk_result {
128                        Ok(bytes) => {
129                            // Append new bytes to buffer
130                            if let Ok(text) = std::str::from_utf8(&bytes) {
131                                buffer.push_str(text);
132                            } else {
133                                yield Err(LlmError::new(ERROR_SSE_DECODE, MSG_INVALID_UTF8));
134                                break;
135                            }
136
137                            // Parse complete SSE events from buffer
138                            let (events, remaining) = sse::parse_sse_chunk(&buffer);
139                            buffer = remaining;
140
141                            // Convert and yield each SSE event
142                            for sse_event in events {
143                                match sse::parse_stream_event(&sse_event, &mut stream_state) {
144                                    Ok(stream_events) => {
145                                        // Emit MessageStart on first content
146                                        if !message_started && !stream_events.is_empty() {
147                                            message_started = true;
148                                            yield Ok(StreamEvent::MessageStart {
149                                                message_id: String::new(),
150                                                model: model.clone(),
151                                            });
152                                        }
153
154                                        for stream_event in stream_events {
155                                            yield Ok(stream_event);
156                                        }
157                                    }
158                                    Err(e) => {
159                                        yield Err(e);
160                                        return;
161                                    }
162                                }
163                            }
164                        }
165                        Err(e) => {
166                            yield Err(e);
167                            break;
168                        }
169                    }
170                }
171
172                // Emit MessageStop at the end
173                if message_started {
174                    yield Ok(StreamEvent::MessageStop);
175                }
176            };
177
178            Ok(Box::pin(event_stream) as Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>)
179        })
180    }
181}