use bytes::Bytes;
use futures::Stream;
use std::pin::Pin;
use crate::core::providers::base::sse::{SSETransformer, UnifiedSSEStream};
use crate::core::providers::unified_provider::ProviderError;
use crate::core::types::responses::{ChatChunk, ChatDelta, ChatStreamChoice};
/// Transformer used for Replicate server-sent-event streams.
///
/// The only state it carries is the static provider label that it hands back
/// from `provider_name`.
#[derive(Debug, Clone)]
pub struct ReplicateTransformer {
    /// Provider label; always `"replicate"` for values built by `new`/`default`.
    provider_name: &'static str,
}

impl ReplicateTransformer {
    /// Label assigned to every transformer instance.
    const PROVIDER_NAME: &'static str = "replicate";

    /// Builds a transformer labelled `"replicate"`.
    pub fn new() -> Self {
        let provider_name = Self::PROVIDER_NAME;
        Self { provider_name }
    }
}

impl Default for ReplicateTransformer {
    /// Same as [`ReplicateTransformer::new`].
    fn default() -> Self {
        ReplicateTransformer::new()
    }
}
impl SSETransformer for ReplicateTransformer {
    /// Returns the provider label stored on this transformer.
    fn provider_name(&self) -> &'static str {
        self.provider_name
    }

    /// Reports whether `data`, with surrounding whitespace ignored, is one of
    /// the terminators `[DONE]` or `done`.
    fn is_end_marker(&self, data: &str) -> bool {
        matches!(data.trim(), "[DONE]" | "done")
    }

    /// Converts one SSE data payload into a chat chunk.
    ///
    /// Resolution order:
    /// 1. Blank payloads and end markers yield `Ok(None)`.
    /// 2. If the payload parses as JSON, the first usable source wins:
    ///    an `output` array (its string elements concatenated, used only when
    ///    the result is non-empty), an `output` string, a `text` string, then
    ///    a `content` string.
    /// 3. If none of those matched and the JSON has a string `error` field,
    ///    that message is returned as
    ///    `ProviderError::replicate_prediction_failed`.
    /// 4. Anything else — non-JSON input, or JSON without a recognised
    ///    field — is emitted verbatim as the chunk content.
    fn transform_chunk(&self, data: &str) -> Result<Option<ChatChunk>, ProviderError> {
        // Nothing to emit for blank payloads or stream terminators.
        if data.trim().is_empty() || self.is_end_marker(data) {
            return Ok(None);
        }

        let parsed = serde_json::from_str::<serde_json::Value>(data).ok();
        if let Some(json) = parsed.as_ref() {
            match json.get("output") {
                Some(serde_json::Value::Array(parts)) => {
                    // Non-string elements are skipped; an empty result falls
                    // through to the remaining field checks.
                    let joined: String = parts.iter().filter_map(|part| part.as_str()).collect();
                    if !joined.is_empty() {
                        return Ok(Some(create_chat_chunk(&joined)));
                    }
                }
                Some(serde_json::Value::String(whole)) => {
                    return Ok(Some(create_chat_chunk(whole)));
                }
                _ => {}
            }

            // `text` takes precedence over `content`.
            let direct = ["text", "content"]
                .iter()
                .find_map(|key| json.get(*key).and_then(|value| value.as_str()));
            if let Some(text) = direct {
                return Ok(Some(create_chat_chunk(text)));
            }

            if let Some(message) = json.get("error").and_then(|value| value.as_str()) {
                return Err(ProviderError::replicate_prediction_failed(message));
            }
        }

        // Fallback: pass the payload through unchanged.
        Ok(Some(create_chat_chunk(data)))
    }
}
/// Wraps `text` in a single-choice `ChatChunk`.
///
/// The chunk is stamped with the current Unix time in seconds (or `0` when the
/// system clock reads earlier than the epoch); the same value is used for
/// `created` and as the suffix of the `chatcmpl-replicate-` id. The sole
/// choice has index `0`, the assistant role, `text` as its content, and no
/// finish reason.
fn create_chat_chunk(text: &str) -> ChatChunk {
    use std::time::{SystemTime, UNIX_EPOCH};

    let created = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    };

    let delta = ChatDelta {
        role: Some(crate::core::types::message::MessageRole::Assistant),
        content: Some(text.to_string()),
        thinking: None,
        tool_calls: None,
        function_call: None,
    };

    let choice = ChatStreamChoice {
        index: 0,
        delta,
        logprobs: None,
        finish_reason: None,
    };

    ChatChunk {
        id: format!("chatcmpl-replicate-{}", created),
        object: "chat.completion.chunk".to_string(),
        created,
        model: "replicate".to_string(),
        system_fingerprint: None,
        choices: vec![choice],
        usage: None,
    }
}
/// Stream type produced by `create_replicate_stream`: a `UnifiedSSEStream`
/// over a pinned, boxed, `Send` stream of `Result<Bytes, reqwest::Error>`
/// items, paired with a `ReplicateTransformer`.
pub type ReplicateStream = UnifiedSSEStream<
Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>,
ReplicateTransformer,
>;
pub fn create_replicate_stream(
stream: impl Stream<Item = Result<Bytes, reqwest::Error>> + Send + 'static,
) -> ReplicateStream {
let transformer = ReplicateTransformer::new();
UnifiedSSEStream::new(Box::pin(stream), transformer)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `input` to a fresh transformer and returns the chunk it yields,
    /// panicking if the call fails or produces no chunk.
    fn chunk_for(input: &str) -> ChatChunk {
        ReplicateTransformer::new()
            .transform_chunk(input)
            .unwrap()
            .unwrap()
    }

    /// Asserts that `input` becomes a chunk whose first choice carries `expected`.
    fn assert_content(input: &str, expected: &str) {
        let chunk = chunk_for(input);
        assert_eq!(chunk.choices[0].delta.content, Some(expected.to_string()));
    }

    /// Asserts that `input` is accepted without producing a chunk.
    fn assert_no_chunk(input: &str) {
        let outcome = ReplicateTransformer::new().transform_chunk(input);
        assert!(matches!(outcome, Ok(None)));
    }

    #[test]
    fn test_replicate_transformer_creation() {
        assert_eq!(ReplicateTransformer::new().provider_name(), "replicate");
    }

    #[test]
    fn test_transform_raw_text() {
        assert_content("Hello, world!", "Hello, world!");
    }

    #[test]
    fn test_transform_empty_line() {
        assert_no_chunk("");
    }

    #[test]
    fn test_transform_done_marker() {
        assert_no_chunk("[DONE]");
        assert_no_chunk("done");
    }

    #[test]
    fn test_transform_json_output_string() {
        assert_content(r#"{"output": "Hello, world!"}"#, "Hello, world!");
    }

    #[test]
    fn test_transform_json_output_array() {
        assert_content(
            r#"{"output": ["Hello", ", ", "world", "!"]}"#,
            "Hello, world!",
        );
    }

    #[test]
    fn test_transform_json_text_field() {
        assert_content(r#"{"text": "Sample text"}"#, "Sample text");
    }

    #[test]
    fn test_transform_json_error() {
        let payload = r#"{"error": "Something went wrong"}"#;
        let outcome = ReplicateTransformer::new().transform_chunk(payload);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_create_chat_chunk() {
        let chunk = create_chat_chunk("Test content");

        // Envelope fields.
        assert!(chunk.id.starts_with("chatcmpl-replicate-"));
        assert_eq!(chunk.object, "chat.completion.chunk");
        assert_eq!(chunk.model, "replicate");

        // Exactly one assistant choice carrying the supplied text.
        assert_eq!(chunk.choices.len(), 1);
        let first = &chunk.choices[0];
        assert_eq!(
            first.delta.role,
            Some(crate::core::types::message::MessageRole::Assistant)
        );
        assert_eq!(first.delta.content, Some("Test content".to_string()));
    }

    #[test]
    fn test_transformer_default() {
        assert_eq!(ReplicateTransformer::default().provider_name(), "replicate");
    }

    #[test]
    fn test_transform_whitespace_only() {
        assert_no_chunk(" ");
    }

    #[test]
    fn test_is_end_marker() {
        let transformer = ReplicateTransformer::new();
        assert!(transformer.is_end_marker("[DONE]"));
        assert!(transformer.is_end_marker("done"));
        assert!(transformer.is_end_marker(" [DONE] "));
        assert!(!transformer.is_end_marker("data"));
    }
}