use crate::adapter::adapters::support::{StreamerCapturedData, StreamerOptions};
use crate::adapter::cohere::CohereAdapter;
use crate::adapter::inter_stream::{InterStreamEnd, InterStreamEvent};
use crate::chat::ChatOptionsSet;
use crate::webc::WebStream;
use crate::{Error, ModelIden, Result};
use serde::Deserialize;
use serde_json::Value;
use std::pin::Pin;
use std::task::{Context, Poll};
use value_ext::JsonValueExt;
pub struct CohereStreamer {
inner: WebStream,
options: StreamerOptions,
done: bool,
captured_data: StreamerCapturedData,
}
impl CohereStreamer {
pub fn new(inner: WebStream, model_iden: ModelIden, options_set: ChatOptionsSet<'_, '_>) -> Self {
Self {
inner,
done: false,
options: StreamerOptions::new(model_iden, options_set),
captured_data: Default::default(),
}
}
}
#[derive(Deserialize, Debug)]
struct CohereStreamMessage {
#[allow(unused)]
is_finished: bool,
event_type: String,
text: Option<String>,
response: Option<CohereStreamMessageResponse>,
}
#[derive(Deserialize, Debug)]
struct CohereStreamMessageResponse {
meta: Option<Value>,
}
impl futures::Stream for CohereStreamer {
type Item = Result<InterStreamEvent>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.done {
return Poll::Ready(None);
}
while let Poll::Ready(item) = Pin::new(&mut self.inner).poll_next(cx) {
match item {
Some(Ok(raw_string)) => {
let cohere_message =
serde_json::from_str::<CohereStreamMessage>(&raw_string).map_err(|serde_error| {
Error::StreamParse {
model_iden: self.options.model_iden.clone(),
serde_error,
}
});
match cohere_message {
Ok(cohere_message) => {
let inter_event = match cohere_message.event_type.as_str() {
"stream-start" => InterStreamEvent::Start,
"text-generation" => {
if let Some(content) = cohere_message.text {
if self.options.capture_content {
match self.captured_data.content {
Some(ref mut c) => c.push_str(&content),
None => self.captured_data.content = Some(content.clone()),
}
}
InterStreamEvent::Chunk(content)
} else {
continue;
}
}
"stream-end" => {
let meta = cohere_message.response.and_then(|r| r.meta);
let captured_usage = if self.options.capture_usage {
meta.and_then(|mut v| v.x_take("tokens").ok())
.map(CohereAdapter::into_usage)
.map(|mut usage| {
if usage.prompt_tokens.is_some() || usage.completion_tokens.is_some() {
usage.total_tokens = Some(
usage.prompt_tokens.unwrap_or(0)
+ usage.completion_tokens.unwrap_or(0),
);
}
usage
})
} else {
None
};
let inter_stream_end = InterStreamEnd {
captured_usage,
captured_text_content: self.captured_data.content.take(),
captured_reasoning_content: self.captured_data.reasoning_content.take(),
captured_tool_calls: self.captured_data.tool_calls.take(),
captured_thought_signatures: None,
};
InterStreamEvent::End(inter_stream_end)
}
_ => continue, };
return Poll::Ready(Some(Ok(inter_event)));
}
Err(err) => {
tracing::error!("Cohere Adapter Stream Error: {}", err);
return Poll::Ready(Some(Err(err)));
}
}
}
Some(Err(err)) => {
tracing::error!("Cohere Adapter Stream Error: {}", err);
return Poll::Ready(Some(Err(Error::WebStream {
model_iden: self.options.model_iden.clone(),
cause: err.to_string(),
error: err,
})));
}
None => {
self.done = true;
return Poll::Ready(None);
}
}
}
Poll::Pending
}
}