use super::{BaseProvider, ModelPricing, Provider, ProviderError, ProviderType};
use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::utils::error::{GatewayError, Result};
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use tracing::{debug, info};
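/// Provider backed by Replicate's asynchronous predictions API, exposed
/// through the gateway's OpenAI-compatible `Provider` trait. Predictions
/// are created and then polled until they settle, so every completion
/// call costs at least two HTTP round trips.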
#[derive(Debug, Clone)]
pub struct ReplicateProvider {
base: BaseProvider,
pricing_cache: HashMap<String, ModelPricing>,
}
impl ReplicateProvider {
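    /// Builds a Replicate provider from the gateway's `ProviderConfig`,
    /// defaulting the base URL to `https://api.replicate.com`.
    ///
    /// A minimal usage sketch; `ProviderConfig` fields other than `name`
    /// and `base_url` are assumptions here, so the example is not compiled:
    ///
    /// ```ignore
    /// let config = ProviderConfig {
    ///     name: "replicate".into(),
    ///     base_url: None, // falls back to https://api.replicate.com
    ///     ..Default::default()
    /// };
    /// let provider = ReplicateProvider::new(&config).await?;
    /// ```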
pub async fn new(config: &ProviderConfig) -> Result<Self> {
let base = BaseProvider::new(config)?;
let base_url = config
.base_url
.clone()
.unwrap_or_else(|| "https://api.replicate.com".to_string());
let provider = Self {
base: BaseProvider { base_url, ..base },
pricing_cache: Self::initialize_pricing_cache(),
};
info!(
"Replicate provider '{}' initialized successfully",
config.name
);
Ok(provider)
}
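    /// Seeds a static pricing table for a handful of popular models. These
    /// figures are hard-coded snapshots rather than live Replicate prices;
    /// models not listed here fall back to the default in `get_model_pricing`.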
fn initialize_pricing_cache() -> HashMap<String, ModelPricing> {
let mut cache = HashMap::new();
cache.insert(
"meta/llama-2-70b-chat".to_string(),
ModelPricing {
model: "meta/llama-2-70b-chat".to_string(),
input_cost_per_1k: 0.00065,
output_cost_per_1k: 0.00275,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"meta/llama-2-13b-chat".to_string(),
ModelPricing {
model: "meta/llama-2-13b-chat".to_string(),
input_cost_per_1k: 0.0001,
output_cost_per_1k: 0.0005,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"mistralai/mixtral-8x7b-instruct-v0.1".to_string(),
ModelPricing {
model: "mistralai/mixtral-8x7b-instruct-v0.1".to_string(),
input_cost_per_1k: 0.0003,
output_cost_per_1k: 0.001,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"stability-ai/stable-diffusion".to_string(),
ModelPricing {
model: "stability-ai/stable-diffusion".to_string(),
input_cost_per_1k: 0.0,
                output_cost_per_1k: 0.0018,
                currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"stability-ai/sdxl".to_string(),
ModelPricing {
model: "stability-ai/sdxl".to_string(),
input_cost_per_1k: 0.0,
                output_cost_per_1k: 0.004,
                currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache
}
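    /// Flattens an OpenAI-style chat history into a single prompt string,
    /// since Replicate language models take one `prompt` input rather than
    /// structured messages. Non-text content parts (e.g. images) are dropped.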
fn convert_messages_to_replicate(&self, messages: &[ChatMessage]) -> String {
messages
.iter()
.map(|msg| {
let role = match msg.role {
MessageRole::System => "System",
MessageRole::User => "User",
MessageRole::Assistant => "Assistant",
MessageRole::Tool => "Tool",
                    MessageRole::Function => "Function",
};
let content = match &msg.content {
Some(MessageContent::Text(text)) => text.clone(),
Some(MessageContent::Parts(parts)) => parts
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.clone()),
_ => None,
})
.collect::<Vec<String>>()
.join(" "),
None => String::new(),
};
format!("{}: {}", role, content)
})
.collect::<Vec<String>>()
.join("\n\n")
}
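    /// Creates a prediction via `POST /v1/predictions`.
    ///
    /// Note: Replicate's predictions endpoint expects a model *version* hash
    /// in the `version` field. Passing the request's model string straight
    /// through assumes callers configure version IDs as model names, or that
    /// the configured `base_url` points at a proxy that resolves
    /// `owner/name` slugs.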
async fn create_prediction(
&self,
model: &str,
input: serde_json::Value,
) -> Result<serde_json::Value> {
let body = json!({
"version": model,
"input": input
});
let url = format!("{}/v1/predictions", self.base.base_url);
let response = self
.base
.client
.post(&url)
.header("Authorization", format!("Token {}", self.base.api_key))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(match status.as_u16() {
401 => ProviderError::Authentication(error_text),
429 => ProviderError::RateLimit(error_text),
404 => ProviderError::ModelNotFound(error_text),
400 => ProviderError::InvalidRequest(error_text),
_ => ProviderError::Unknown(format!("HTTP {}: {}", status, error_text)),
}
.into());
}
let prediction: serde_json::Value = self.base.parse_json_response(response).await?;
Ok(prediction)
}
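    /// Polls a prediction every 2 s until it settles, giving up after 30
    /// attempts (roughly a one-minute budget) and reporting a timeout.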
async fn wait_for_prediction(&self, prediction_id: &str) -> Result<serde_json::Value> {
let url = format!("{}/v1/predictions/{}", self.base.base_url, prediction_id);
for _ in 0..30 {
let response = self
.base
.client
.get(&url)
.header("Authorization", format!("Token {}", self.base.api_key))
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
let prediction: serde_json::Value = self.base.parse_json_response(response).await?;
let status = prediction
.get("status")
.and_then(|s| s.as_str())
.unwrap_or("unknown");
match status {
"succeeded" => return Ok(prediction),
"failed" | "canceled" => {
let error = prediction
.get("error")
.and_then(|e| e.as_str())
.unwrap_or("Prediction failed");
return Err(ProviderError::Unknown(error.to_string()).into());
}
"starting" | "processing" => {
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
}
_ => {
return Err(
ProviderError::Unknown(format!("Unknown status: {}", status)).into(),
);
}
}
}
Err(ProviderError::Unknown("Prediction timed out".to_string()).into())
}
}
#[async_trait]
impl Provider for ReplicateProvider {
fn name(&self) -> &str {
&self.base.name
}
fn provider_type(&self) -> ProviderType {
ProviderType::Custom("replicate".to_string())
}
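    /// Accepts explicitly configured models, plus any `owner/name` slug,
    /// since virtually all Replicate model identifiers contain a slash.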
    async fn supports_model(&self, model: &str) -> bool {
        self.base.is_model_supported(model) || model.contains('/')
    }
    async fn supports_images(&self) -> bool {
        true
    }
    async fn supports_embeddings(&self) -> bool {
        false
    }
    async fn supports_streaming(&self) -> bool {
        false
    }
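    /// Replicate hosts far more models than can be enumerated cheaply, so
    /// this returns a small curated list instead of querying the catalog.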
async fn list_models(&self) -> Result<Vec<Model>> {
let known_models = vec![
"meta/llama-2-70b-chat",
"meta/llama-2-13b-chat",
"meta/llama-2-7b-chat",
"mistralai/mixtral-8x7b-instruct-v0.1",
"stability-ai/stable-diffusion",
"stability-ai/sdxl",
"openai/whisper",
"salesforce/blip",
];
let models = known_models
.into_iter()
.map(|model| Model {
id: model.to_string(),
object: "model".to_string(),
created: chrono::Utc::now().timestamp() as u64,
owned_by: "replicate".to_string(),
})
.collect();
Ok(models)
}
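    /// Uses `GET /v1/account` as a lightweight authenticated liveness probe.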
async fn health_check(&self) -> Result<()> {
debug!("Performing Replicate health check");
let url = format!("{}/v1/account", self.base.base_url);
let response = self
.base
.client
.get(&url)
.header("Authorization", format!("Token {}", self.base.api_key))
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
if response.status().is_success() {
Ok(())
} else {
Err(
ProviderError::Unknown(format!("Health check failed: {}", response.status()))
.into(),
)
}
}
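    /// Runs a chat completion as a create-then-poll prediction. Replicate
    /// does not return token counts, so `usage` is left at
    /// `Usage::default()` and `finish_reason` is always `"stop"`.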
async fn chat_completion(
&self,
request: ChatCompletionRequest,
_context: RequestContext,
) -> Result<ChatCompletionResponse> {
debug!("Replicate chat completion for model: {}", request.model);
let prompt = self.convert_messages_to_replicate(&request.messages);
let mut input = json!({
"prompt": prompt
});
if let Some(max_tokens) = request.max_tokens {
input["max_new_tokens"] = json!(max_tokens);
}
if let Some(temperature) = request.temperature {
input["temperature"] = json!(temperature);
}
if let Some(top_p) = request.top_p {
input["top_p"] = json!(top_p);
}
let prediction = self.create_prediction(&request.model, input).await?;
let prediction_id = prediction
.get("id")
.and_then(|id| id.as_str())
.ok_or_else(|| ProviderError::Parsing("No prediction ID in response".to_string()))?;
let completed_prediction = self.wait_for_prediction(prediction_id).await?;
let output = completed_prediction
.get("output")
.ok_or_else(|| ProviderError::Parsing("No output in prediction".to_string()))?;
let content = match output {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(arr) => arr
.iter()
.filter_map(|v| v.as_str())
.collect::<Vec<&str>>()
.join(""),
_ => output.to_string(),
};
Ok(ChatCompletionResponse {
id: format!("chatcmpl-replicate-{}", prediction_id),
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: request.model,
choices: vec![ChatChoice {
index: 0,
message: ChatMessage {
role: MessageRole::Assistant,
content: Some(MessageContent::Text(content)),
name: None,
function_call: None,
tool_calls: None,
tool_call_id: None,
audio: None,
},
finish_reason: Some("stop".to_string()),
logprobs: None,
}],
            usage: Some(Usage::default()),
            system_fingerprint: None,
})
}
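    /// Text completion follows the same create-then-poll flow as chat, but
    /// forwards `request.prompt` directly instead of flattening messages.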
async fn completion(
&self,
request: CompletionRequest,
_context: RequestContext,
) -> Result<CompletionResponse> {
debug!("Replicate completion for model: {}", request.model);
let mut input = json!({
"prompt": request.prompt
});
if let Some(max_tokens) = request.max_tokens {
input["max_new_tokens"] = json!(max_tokens);
}
if let Some(temperature) = request.temperature {
input["temperature"] = json!(temperature);
}
if let Some(top_p) = request.top_p {
input["top_p"] = json!(top_p);
}
let prediction = self.create_prediction(&request.model, input).await?;
let prediction_id = prediction
.get("id")
.and_then(|id| id.as_str())
.ok_or_else(|| ProviderError::Parsing("No prediction ID in response".to_string()))?;
let completed_prediction = self.wait_for_prediction(prediction_id).await?;
let output = completed_prediction
.get("output")
.ok_or_else(|| ProviderError::Parsing("No output in prediction".to_string()))?;
let text = match output {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(arr) => arr
.iter()
.filter_map(|v| v.as_str())
.collect::<Vec<&str>>()
.join(""),
_ => output.to_string(),
};
Ok(CompletionResponse {
id: format!("cmpl-replicate-{}", prediction_id),
object: "text_completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: request.model,
choices: vec![CompletionChoice {
text,
index: 0,
logprobs: None,
finish_reason: Some("stop".to_string()),
}],
usage: Some(Usage::default()),
})
}
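    /// Embeddings are not mapped by this provider (`supports_embeddings`
    /// returns `false`), so the request is rejected up front.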
async fn embedding(
&self,
_request: EmbeddingRequest,
_context: RequestContext,
) -> Result<EmbeddingResponse> {
Err(
ProviderError::InvalidRequest("Embeddings not supported by Replicate".to_string())
.into(),
)
}
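    /// Parses OpenAI-style `size` strings such as `"1024x1024"` into
    /// `width`/`height` inputs and returns the prediction's output URLs;
    /// base64 payloads are never produced.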
async fn image_generation(
&self,
request: ImageGenerationRequest,
_context: RequestContext,
) -> Result<ImageGenerationResponse> {
debug!("Replicate image generation for model: {:?}", request.model);
let mut input = json!({
"prompt": request.prompt
});
if let Some(n) = request.n {
input["num_outputs"] = json!(n);
}
if let Some(size) = &request.size {
if let Some((width, height)) = size.split_once('x') {
if let (Ok(w), Ok(h)) = (width.parse::<u32>(), height.parse::<u32>()) {
input["width"] = json!(w);
input["height"] = json!(h);
}
}
}
let model_str = request.model.as_ref().ok_or_else(|| {
GatewayError::InvalidRequest("Model is required for image generation".to_string())
})?;
let prediction = self.create_prediction(model_str, input).await?;
let prediction_id = prediction
.get("id")
.and_then(|id| id.as_str())
.ok_or_else(|| ProviderError::Parsing("No prediction ID in response".to_string()))?;
let completed_prediction = self.wait_for_prediction(prediction_id).await?;
let output = completed_prediction
.get("output")
.ok_or_else(|| ProviderError::Parsing("No output in prediction".to_string()))?;
let urls = match output {
serde_json::Value::String(url) => vec![url.clone()],
serde_json::Value::Array(arr) => arr
.iter()
.filter_map(|v| v.as_str())
.map(|s| s.to_string())
.collect(),
_ => return Err(ProviderError::Parsing("Invalid output format".to_string()).into()),
};
let data = urls
.into_iter()
.map(|url| ImageObject {
url: Some(url),
b64_json: None,
})
.collect();
Ok(ImageGenerationResponse {
created: chrono::Utc::now().timestamp() as u64,
data,
})
}
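    /// Returns cached pricing when the model is known, otherwise a generic
    /// placeholder rate so cost accounting still yields a number.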
async fn get_model_pricing(&self, model: &str) -> Result<ModelPricing> {
if let Some(pricing) = self.pricing_cache.get(model) {
Ok(pricing.clone())
} else {
Ok(ModelPricing {
model: model.to_string(),
input_cost_per_1k: 0.001,
output_cost_per_1k: 0.002,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
})
}
}
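    /// Linear per-token pricing: `tokens / 1000 * rate` for each direction,
    /// summed.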
async fn calculate_cost(
&self,
model: &str,
input_tokens: u32,
output_tokens: u32,
) -> Result<f64> {
let pricing = self.get_model_pricing(model).await?;
let input_cost = (input_tokens as f64 / 1000.0) * pricing.input_cost_per_1k;
let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_cost_per_1k;
Ok(input_cost + output_cost)
}
}