Skip to main content

crabtalk_model/
provider.rs

1//! Provider implementation backed by crabllm-provider.
2//!
3//! Wraps `crabllm_provider::Provider` behind wcore's `Model` trait with
4//! type conversion and retry logic.
5
6use crate::{config::ProviderDef, convert};
7use anyhow::Result;
8use async_stream::try_stream;
9use crabllm_core::ApiError;
10use crabllm_provider::Provider as CtProvider;
11use futures_core::Stream;
12use futures_util::StreamExt;
13use rand::Rng;
14use std::time::Duration;
15use wcore::model::{Model, Response, StreamChunk};
16
17/// Unified LLM provider wrapping a crabtalk provider instance.
18#[derive(Clone)]
19pub struct Provider {
20    inner: CtProvider,
21    client: reqwest::Client,
22    model: String,
23    max_retries: u32,
24    timeout: Duration,
25}
26
27impl Provider {
28    /// Get the model name this provider was constructed for.
29    pub fn model_name(&self) -> &String {
30        &self.model
31    }
32}
33
/// Reduce a configured base URL to a bare API root.
///
/// Trailing slashes are removed first. Then, if the URL ends in one of the
/// known endpoint paths (`/chat/completions`, `/messages`, `/embeddings`),
/// that single suffix is dropped. Users can therefore configure either an API
/// root (`https://host/v1`) or a full endpoint URL
/// (`https://host/v1/chat/completions`).
fn normalize_base_url(url: &str) -> String {
    const ENDPOINT_SUFFIXES: [&str; 3] = ["/chat/completions", "/messages", "/embeddings"];

    let without_slash = url.trim_end_matches('/');
    ENDPOINT_SUFFIXES
        .iter()
        .find_map(|&suffix| without_slash.strip_suffix(suffix))
        .unwrap_or(without_slash)
        .to_owned()
}
44
45/// Construct a `Provider` from a provider definition and model name.
46pub fn build_provider(def: &ProviderDef, model: &str, client: reqwest::Client) -> Result<Provider> {
47    let mut config = def.clone();
48    config.kind = config.effective_kind();
49    let mut inner = CtProvider::from(&config);
50
51    // Apply crabtalk-specific base_url normalization (strip endpoint suffixes).
52    if let CtProvider::OpenAiCompat {
53        ref mut base_url, ..
54    } = inner
55    {
56        *base_url = normalize_base_url(base_url);
57    }
58
59    Ok(Provider {
60        inner,
61        client,
62        model: model.to_owned(),
63        max_retries: def.max_retries.unwrap_or(2),
64        timeout: Duration::from_secs(def.timeout.unwrap_or(30)),
65    })
66}
67
68impl Model for Provider {
69    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
70        let mut ct_req = convert::to_ct_request(request);
71        ct_req.stream = Some(false);
72        send_with_retry(
73            &self.inner,
74            &self.client,
75            &ct_req,
76            self.max_retries,
77            self.timeout,
78        )
79        .await
80    }
81
82    fn stream(
83        &self,
84        request: wcore::model::Request,
85    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
86        let inner = self.inner.clone();
87        let client = self.client.clone();
88        let timeout = self.timeout;
89        try_stream! {
90            let mut ct_req = convert::to_ct_request(&request);
91            ct_req.stream = Some(true);
92
93            let boxed = tokio::time::timeout(timeout, inner.chat_completion_stream(&client, &ct_req))
94                .await
95                .map_err(|_| anyhow::anyhow!("stream connection timed out"))?
96                .map_err(format_provider_error)?;
97
98            let mut stream = std::pin::pin!(boxed);
99            while let Some(chunk) = stream.next().await {
100                let ct_chunk = chunk.map_err(format_provider_error)?;
101                yield convert::from_ct_chunk(ct_chunk);
102            }
103        }
104    }
105
106    fn context_limit(&self, model: &str) -> usize {
107        wcore::model::default_context_limit(model)
108    }
109
110    fn active_model(&self) -> String {
111        self.model.clone()
112    }
113}
114
115/// Send a non-streaming request with exponential backoff retry on transient errors.
116async fn send_with_retry(
117    provider: &CtProvider,
118    client: &reqwest::Client,
119    request: &crabllm_core::ChatCompletionRequest,
120    max_retries: u32,
121    timeout: Duration,
122) -> Result<Response> {
123    let mut backoff = Duration::from_millis(100);
124    let mut last_err = None;
125
126    for _ in 0..=max_retries {
127        let result = if timeout.is_zero() {
128            provider.chat_completion(client, request).await
129        } else {
130            tokio::time::timeout(timeout, provider.chat_completion(client, request))
131                .await
132                .map_err(|_| crabllm_core::Error::Timeout)?
133        };
134
135        match result {
136            Ok(resp) => return Ok(convert::from_ct_response(resp)),
137            Err(e) if e.is_transient() => {
138                last_err = Some(e);
139                let jitter = jittered(backoff);
140                tokio::time::sleep(jitter).await;
141                backoff *= 2;
142            }
143            Err(e) => return Err(format_provider_error(e)),
144        }
145    }
146
147    Err(format_provider_error(last_err.unwrap()))
148}
149
150/// Full jitter: random duration in [backoff/2, backoff].
151fn jittered(backoff: Duration) -> Duration {
152    let lo = backoff.as_millis() as u64 / 2;
153    let hi = backoff.as_millis() as u64;
154    if lo >= hi {
155        return backoff;
156    }
157    Duration::from_millis(rand::rng().random_range(lo..=hi))
158}
159
160/// Convert a crabllm error into an anyhow error with a human-readable message.
161///
162/// For provider HTTP errors, attempts to parse the response body as an
163/// OpenAI-compatible API error and extract the `message` field.
164fn format_provider_error(e: crabllm_core::Error) -> anyhow::Error {
165    match e {
166        crabllm_core::Error::Provider { status, body } => {
167            let msg = serde_json::from_str::<ApiError>(&body)
168                .map(|api_err| api_err.error.message)
169                .unwrap_or_else(|_| truncate(&body, 200));
170            anyhow::anyhow!("provider error (HTTP {status}): {msg}")
171        }
172        other => anyhow::anyhow!("{other}"),
173    }
174}
175
/// Return at most the first `max` characters of `s`, appending `...` when
/// anything was cut off.
///
/// Counting is by `char`, so the cut always lands on a UTF-8 boundary.
fn truncate(s: &str, max: usize) -> String {
    let Some((cut, _)) = s.char_indices().nth(max) else {
        return s.to_string();
    };
    let mut shortened = String::with_capacity(cut + 3);
    shortened.push_str(&s[..cut]);
    shortened.push_str("...");
    shortened
}