mod sse;
mod types;
use async_stream::stream;
use futures::Stream;
use crate::client::error::LlmError;
use crate::client::http::HttpClient;
use crate::client::models::{Message, MessageOptions, StreamEvent};
use crate::client::traits::{LlmProvider, StreamMsgFuture};
use std::future::Future;
use std::pin::Pin;
/// Error identifier reported when the streamed response cannot be decoded.
const ERROR_SSE_DECODE: &str = "SSE_DECODE_ERROR";
/// Error message reported when the streamed response bytes are not valid UTF-8.
const MSG_INVALID_UTF8: &str = "Invalid UTF-8 in stream";

/// LLM provider for Cohere.
///
/// Stores the API key used to authenticate every request and a default model
/// name, which a request's `MessageOptions::model` can override.
pub struct CohereProvider {
    /// Credential handed to the request-header builder for each call.
    api_key: String,
    /// Model used when the caller's options do not specify one.
    model: String,
}

impl CohereProvider {
    /// Builds a provider that authenticates with `api_key` and defaults to `model`.
    pub fn new(api_key: String, model: String) -> Self {
        CohereProvider { model, api_key }
    }

    /// Returns the default model name this provider was constructed with.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }
}
impl LlmProvider for CohereProvider {
    /// Sends `messages` and waits for the complete reply.
    ///
    /// `options.model`, when set, takes precedence over the provider's default
    /// model. The client, key, messages and options are cloned into the returned
    /// future so it does not borrow from the arguments.
    ///
    /// # Errors
    ///
    /// Propagates any [`LlmError`] from building the request body or headers,
    /// from the HTTP call, or from `types::parse_response`.
    fn send_msg(
        &self,
        client: &HttpClient,
        messages: &[Message],
        options: &MessageOptions,
    ) -> Pin<Box<dyn Future<Output = Result<Message, LlmError>> + Send>> {
        let client = client.clone();
        let api_key = self.api_key.clone();
        let model = options.model.as_deref().unwrap_or(&self.model).to_string();
        let messages = messages.to_vec();
        let options = options.clone();
        Box::pin(async move {
            let body = types::build_request_body(&messages, &options, &model)?;
            let headers = types::get_request_headers(&api_key)?;
            let headers_ref: Vec<(&str, &str)> =
                headers.iter().map(|(k, v)| (*k, v.as_str())).collect();
            let url = types::get_api_url();
            let response = client.post(&url, &headers_ref, &body).await?;
            types::parse_response(&response)
        })
    }

    /// Sends `messages` and returns a stream of incremental [`StreamEvent`]s.
    ///
    /// Before the first non-empty batch of parsed events, a `MessageStart`
    /// (with an empty `message_id` and the resolved model) is yielded. Once a
    /// message has started, `MessageStop` is yielded when the byte stream ends
    /// or when a transport/decoding error stops it. A parse error from
    /// `sse::parse_stream_event` ends the stream immediately without
    /// `MessageStop`.
    ///
    /// Response bytes are decoded as UTF-8 incrementally. A multi-byte
    /// character split across two network chunks is held back until the rest
    /// of it arrives, instead of being reported as invalid.
    ///
    /// # Errors
    ///
    /// The returned future fails if the request body or headers cannot be
    /// built, or if the HTTP request fails. Later failures (transport errors,
    /// invalid UTF-8, SSE parse errors) are delivered as `Err` items in the
    /// stream.
    fn send_msg_stream(
        &self,
        client: &HttpClient,
        messages: &[Message],
        options: &MessageOptions,
    ) -> StreamMsgFuture {
        let client = client.clone();
        let api_key = self.api_key.clone();
        let model = options.model.as_deref().unwrap_or(&self.model).to_string();
        let messages = messages.to_vec();
        let options = options.clone();
        Box::pin(async move {
            let body = types::build_streaming_request_body(&messages, &options, &model)?;
            let headers = types::get_request_headers(&api_key)?;
            let headers_ref: Vec<(&str, &str)> =
                headers.iter().map(|(k, v)| (*k, v.as_str())).collect();
            let url = types::get_api_url();
            let byte_stream = client.post_stream(&url, &headers_ref, &body).await?;
            use futures::StreamExt;
            let event_stream = stream! {
                // Decoded text not yet consumed by the SSE parser.
                let mut buffer = String::new();
                // Leading bytes of a UTF-8 character cut off at the end of the
                // previous chunk; they are completed by the next chunk.
                let mut pending_bytes: Vec<u8> = Vec::new();
                let mut byte_stream = byte_stream;
                let mut message_started = false;
                let mut stream_state = sse::StreamState::default();
                loop {
                    let chunk_result = match byte_stream.next().await {
                        Some(chunk_result) => chunk_result,
                        None => {
                            // The body ended partway through a multi-byte
                            // character, so it was not valid UTF-8 overall.
                            if !pending_bytes.is_empty() {
                                yield Err(LlmError::new(ERROR_SSE_DECODE, MSG_INVALID_UTF8));
                            }
                            break;
                        }
                    };
                    let bytes = match chunk_result {
                        Ok(bytes) => bytes,
                        Err(e) => {
                            yield Err(e);
                            break;
                        }
                    };
                    pending_bytes.extend_from_slice(&bytes);
                    match take_complete_utf8(&mut pending_bytes) {
                        Ok(text) => buffer.push_str(&text),
                        Err(_) => {
                            yield Err(LlmError::new(ERROR_SSE_DECODE, MSG_INVALID_UTF8));
                            break;
                        }
                    }
                    // `remaining` is presumably the unterminated tail of the
                    // buffer. It is carried into the next chunk, and it is never
                    // parsed if the stream ends first.
                    let (events, remaining) = sse::parse_sse_chunk(&buffer);
                    buffer = remaining;
                    for sse_event in events {
                        match sse::parse_stream_event(&sse_event, &mut stream_state) {
                            Ok(stream_events) => {
                                if !message_started && !stream_events.is_empty() {
                                    message_started = true;
                                    yield Ok(StreamEvent::MessageStart {
                                        message_id: String::new(),
                                        model: model.clone(),
                                    });
                                }
                                for stream_event in stream_events {
                                    yield Ok(stream_event);
                                }
                            }
                            Err(e) => {
                                // A malformed event ends the stream without MessageStop.
                                yield Err(e);
                                return;
                            }
                        }
                    }
                }
                if message_started {
                    yield Ok(StreamEvent::MessageStop);
                }
            };
            Ok(Box::pin(event_stream)
                as Pin<
                    Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>,
                >)
        })
    }
}

/// Removes the longest valid UTF-8 prefix from `pending` and returns it as a
/// `String`.
///
/// Any trailing bytes of an incomplete multi-byte character (at most 3) stay
/// in `pending`, so the next network chunk can complete them.
///
/// # Errors
///
/// Returns the [`std::str::Utf8Error`] when `pending` contains a byte sequence
/// that is invalid no matter what bytes follow it. `pending` is left untouched
/// in that case.
fn take_complete_utf8(pending: &mut Vec<u8>) -> Result<String, std::str::Utf8Error> {
    let valid_len = match std::str::from_utf8(pending) {
        Ok(_) => pending.len(),
        // `error_len() == None` means the input merely ends mid-character.
        Err(e) if e.error_len().is_none() => e.valid_up_to(),
        Err(e) => return Err(e),
    };
    let tail = pending.split_off(valid_len);
    let complete = std::mem::replace(pending, tail);
    // `complete` is exactly the prefix validated above, so this cannot fail.
    String::from_utf8(complete).map_err(|e| e.utf8_error())
}