use anyhow::Result;
use async_stream::try_stream;
use crabllm_core::{
ApiError, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, Provider,
};
use futures_core::Stream;
use futures_util::StreamExt;
use std::sync::Arc;
/// Handle around a [`Provider`] implementation.
///
/// The provider is stored behind an [`Arc`], so the handle itself only
/// carries a reference-counted pointer.
pub struct Model<P>
where
    P: Provider + 'static,
{
    /// Shared pointer to the wrapped provider.
    inner: Arc<P>,
}
impl<P: Provider + 'static> Model<P> {
    /// Takes ownership of `provider`, places it in a new [`Arc`], and wraps it.
    pub fn new(provider: P) -> Self {
        Self::from_arc(Arc::new(provider))
    }

    /// Wraps a provider that is already behind an [`Arc`].
    pub fn from_arc(provider: Arc<P>) -> Self {
        Self { inner: provider }
    }

    /// Performs a non-streaming chat completion.
    ///
    /// The request's `stream` field is overwritten with `Some(false)` before
    /// the request is passed to the provider's `chat_completion`.
    ///
    /// # Errors
    ///
    /// A provider error is converted with `format_provider_error`, tagged
    /// with the request's model and the operation name `"send"`.
    pub async fn send_ct(
        &self,
        mut request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse> {
        request.stream = Some(false);
        let label = request.model.clone();
        match self.inner.chat_completion(&request).await {
            Ok(response) => Ok(response),
            Err(err) => Err(format_provider_error(&label, "send", err)),
        }
    }

    /// Performs a streaming chat completion and returns the chunks as a stream.
    ///
    /// The request's `stream` field is overwritten with `Some(true)`. The
    /// returned stream owns its own `Arc` clone of the provider and the
    /// request, so it does not borrow from `self`.
    ///
    /// # Errors
    ///
    /// The stream yields an error (and then ends) if opening the provider
    /// stream fails (operation name `"stream open"`) or if any chunk is an
    /// error (operation name `"stream chunk"`); both are converted with
    /// `format_provider_error`.
    pub fn stream_ct(
        &self,
        mut request: ChatCompletionRequest,
    ) -> impl Stream<Item = Result<ChatCompletionChunk>> + Send + 'static {
        request.stream = Some(true);
        let label = request.model.clone();
        let provider = Arc::clone(&self.inner);
        try_stream! {
            let mut chunks = provider
                .chat_completion_stream(&request)
                .await
                .map_err(|err| format_provider_error(&label, "stream open", err))?;
            while let Some(item) = chunks.next().await {
                // Convert a failed chunk first; `?` ends the stream on error.
                let chunk = item
                    .map_err(|err| format_provider_error(&label, "stream chunk", err))?;
                yield chunk;
            }
        }
    }
}
/// Cloning a `Model` clones the [`Arc`], not the provider, so `P` itself
/// does not need to implement `Clone`.
impl<P: Provider + 'static> Clone for Model<P> {
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
/// Debug output is just the type name; no fields are added to the builder,
/// so this impl places no `Debug` bound on `P`.
impl<P: Provider + 'static> std::fmt::Debug for Model<P> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Model");
        builder.finish()
    }
}
fn format_provider_error(model: &str, op: &str, e: crabllm_core::Error) -> anyhow::Error {
match e {
crabllm_core::Error::Provider { status, body } => {
let msg = serde_json::from_str::<ApiError>(&body)
.map(|api_err| api_err.error.message)
.unwrap_or_else(|_| truncate(&body, 200));
anyhow::anyhow!("model {op} failed for '{model}' (HTTP {status}): {msg}")
}
other => anyhow::anyhow!("model {op} failed for '{model}': {other}"),
}
}
/// Returns `s` limited to its first `max` characters.
///
/// When `s` has more than `max` characters, the first `max` are kept and
/// `"..."` is appended; otherwise `s` is returned unchanged as an owned
/// `String`. The cut is made on a character boundary, never inside a
/// multi-byte character.
fn truncate(s: &str, max: usize) -> String {
    // `nth(max)` is the character just past the part we keep; its byte
    // offset is where to cut. `None` means `s` has at most `max` characters.
    if let Some((cut, _)) = s.char_indices().nth(max) {
        return [&s[..cut], "..."].concat();
    }
    s.to_owned()
}