agent_core_runtime/client/providers/cohere/
mod.rs1mod sse;
4mod types;
5
6use async_stream::stream;
7use futures::Stream;
8
9use crate::client::error::LlmError;
10use crate::client::http::HttpClient;
11use crate::client::models::{Message, MessageOptions, StreamEvent};
12use crate::client::traits::LlmProvider;
13use std::future::Future;
14use std::pin::Pin;
15
/// First argument passed to `LlmError::new` when streamed bytes cannot be decoded as UTF-8
/// (presumably the machine-readable error code — confirm against `LlmError`).
const ERROR_SSE_DECODE: &str = "SSE_DECODE_ERROR";
22
/// Second argument passed to `LlmError::new` alongside `ERROR_SSE_DECODE`
/// (presumably the human-readable message — confirm against `LlmError`).
const MSG_INVALID_UTF8: &str = "Invalid UTF-8 in stream";
25
/// Provider state for talking to Cohere: the credential and the default model name.
pub struct CohereProvider {
    /// Key handed to `types::get_request_headers` when building each request.
    api_key: String,
    /// Model used whenever the per-call options do not name one.
    model: String,
}
39
40impl CohereProvider {
41 pub fn new(api_key: String, model: String) -> Self {
43 Self { api_key, model }
44 }
45
46 pub fn model(&self) -> &str {
48 &self.model
49 }
50}
51
52impl LlmProvider for CohereProvider {
53 fn send_msg(
54 &self,
55 client: &HttpClient,
56 messages: &[Message],
57 options: &MessageOptions,
58 ) -> Pin<Box<dyn Future<Output = Result<Message, LlmError>> + Send>> {
59 let client = client.clone();
61 let api_key = self.api_key.clone();
62 let model = options.model.as_deref().unwrap_or(&self.model).to_string();
63 let messages = messages.to_vec();
64 let options = options.clone();
65
66 Box::pin(async move {
67 let body = types::build_request_body(&messages, &options, &model)?;
69
70 let headers = types::get_request_headers(&api_key)?;
72 let headers_ref: Vec<(&str, &str)> = headers
73 .iter()
74 .map(|(k, v)| (*k, v.as_str()))
75 .collect();
76
77 let url = types::get_api_url();
79
80 let response = client.post(&url, &headers_ref, &body).await?;
82
83 types::parse_response(&response)
85 })
86 }
87
88 fn send_msg_stream(
89 &self,
90 client: &HttpClient,
91 messages: &[Message],
92 options: &MessageOptions,
93 ) -> Pin<Box<dyn Future<Output = Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>, LlmError>> + Send>> {
94 let client = client.clone();
96 let api_key = self.api_key.clone();
97 let model = options.model.as_deref().unwrap_or(&self.model).to_string();
98 let messages = messages.to_vec();
99 let options = options.clone();
100
101 Box::pin(async move {
102 let body = types::build_streaming_request_body(&messages, &options, &model)?;
104
105 let headers = types::get_request_headers(&api_key)?;
107 let headers_ref: Vec<(&str, &str)> = headers
108 .iter()
109 .map(|(k, v)| (*k, v.as_str()))
110 .collect();
111
112 let url = types::get_api_url();
114
115 let byte_stream = client.post_stream(&url, &headers_ref, &body).await?;
117
118 use futures::StreamExt;
120 let event_stream = stream! {
121 let mut buffer = String::new();
122 let mut byte_stream = byte_stream;
123 let mut message_started = false;
124 let mut stream_state = sse::StreamState::default();
125
126 while let Some(chunk_result) = byte_stream.next().await {
127 match chunk_result {
128 Ok(bytes) => {
129 if let Ok(text) = std::str::from_utf8(&bytes) {
131 buffer.push_str(text);
132 } else {
133 yield Err(LlmError::new(ERROR_SSE_DECODE, MSG_INVALID_UTF8));
134 break;
135 }
136
137 let (events, remaining) = sse::parse_sse_chunk(&buffer);
139 buffer = remaining;
140
141 for sse_event in events {
143 match sse::parse_stream_event(&sse_event, &mut stream_state) {
144 Ok(stream_events) => {
145 if !message_started && !stream_events.is_empty() {
147 message_started = true;
148 yield Ok(StreamEvent::MessageStart {
149 message_id: String::new(),
150 model: model.clone(),
151 });
152 }
153
154 for stream_event in stream_events {
155 yield Ok(stream_event);
156 }
157 }
158 Err(e) => {
159 yield Err(e);
160 return;
161 }
162 }
163 }
164 }
165 Err(e) => {
166 yield Err(e);
167 break;
168 }
169 }
170 }
171
172 if message_started {
174 yield Ok(StreamEvent::MessageStop);
175 }
176 };
177
178 Ok(Box::pin(event_stream) as Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>)
179 })
180 }
181}