Skip to main content

rover/summarizer/
cloud.rs

1//! Cloud-backed summarizer wrapping `genai::Client`.
2//!
3//! Supports every provider `genai` ships natively (OpenAI, Anthropic,
4//! Gemini, xAI, Groq, DeepSeek, Together, Fireworks) plus a custom
5//! `openai_compat` kind that points at any OpenAI-compatible endpoint
6//! via a `ServiceTargetResolver`.
7
8use async_trait::async_trait;
9use genai::chat::{ChatMessage, ChatRequest};
10use genai::resolver::{AuthData, AuthResolver, Endpoint, ServiceTargetResolver};
11use genai::{Client, ServiceTarget};
12
13use crate::summarizer::backend::{CompactMode, CompactOpts, SummarizerBackend};
14use crate::summarizer::error::BackendError;
15use crate::summarizer::prompts::render_abstractive;
16
/// Provider kind parsed from `[backends.<name>] provider = "..."`.
///
/// Each variant is selected by exactly one config string, shown on the
/// variant. Matching is exact (case-sensitive, no trimming).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProviderKind {
    /// Selected by `"openai"`.
    OpenAi,
    /// Selected by `"anthropic"`.
    Anthropic,
    /// Selected by `"gemini"`.
    Gemini,
    /// Selected by `"xai"`.
    XAi,
    /// Selected by `"groq"`.
    Groq,
    /// Selected by `"deepseek"`.
    DeepSeek,
    /// Selected by `"together"`.
    Together,
    /// Selected by `"fireworks"`.
    Fireworks,
    /// Selected by `"openai_compat"`: custom base_url speaking the OpenAI
    /// Chat Completions shape.
    OpenAiCompat,
}

impl ProviderKind {
    /// Map a config string onto its provider kind.
    ///
    /// # Errors
    ///
    /// Returns `Err("unknown provider: <input>")` for any string that is not
    /// one of the nine supported identifiers.
    // Consumed in Task 6 (registry) when mapping `[backends.<name>] provider = "..."`.
    pub fn parse(s: &str) -> Result<Self, String> {
        match s {
            "openai" => Ok(ProviderKind::OpenAi),
            "anthropic" => Ok(ProviderKind::Anthropic),
            "gemini" => Ok(ProviderKind::Gemini),
            "xai" => Ok(ProviderKind::XAi),
            "groq" => Ok(ProviderKind::Groq),
            "deepseek" => Ok(ProviderKind::DeepSeek),
            "together" => Ok(ProviderKind::Together),
            "fireworks" => Ok(ProviderKind::Fireworks),
            "openai_compat" => Ok(ProviderKind::OpenAiCompat),
            other => Err(format!("unknown provider: {other}")),
        }
    }
}

/// Standard-trait entry point so callers can write `s.parse::<ProviderKind>()`.
/// Delegates to [`ProviderKind::parse`], so both paths accept the same inputs
/// and produce the same error text.
impl std::str::FromStr for ProviderKind {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::parse(s)
    }
}
49
50/// Build a `genai::Client` configured for the given provider.
51///
52/// For `OpenAiCompat`, a `ServiceTargetResolver` is installed that rewrites
53/// any request for `model` to `base_url` using the OpenAI wire shape.  All
54/// other providers use genai's built-in env-var key resolution unless
55/// `api_key` is supplied.
56///
57/// Shared by `CloudBackend` (summarizer) and `CloudCaptioner` (vlm) so that
58/// provider-resolution logic lives in exactly one place.
59pub fn build_client(
60    provider: ProviderKind,
61    base_url: Option<&str>,
62    api_key: Option<&str>,
63) -> Result<Client, String> {
64    let mut builder = Client::builder();
65
66    if provider == ProviderKind::OpenAiCompat {
67        let base = normalize_openai_compat_base_url(
68            base_url.ok_or_else(|| "openai_compat requires base_url".to_string())?,
69        );
70        let key_for_resolver = api_key.unwrap_or("noop").to_string();
71        let resolver = ServiceTargetResolver::from_resolver_fn(
72            move |service_target: ServiceTarget| -> Result<ServiceTarget, genai::resolver::Error> {
73                // Force any request through the configured base_url using the
74                // OpenAI Chat Completions wire shape.
75                let mut model = service_target.model;
76                model.adapter_kind = genai::adapter::AdapterKind::OpenAI;
77                Ok(ServiceTarget {
78                    endpoint: Endpoint::from_owned(base.clone()),
79                    auth: AuthData::from_single(key_for_resolver.clone()),
80                    model,
81                })
82            },
83        );
84        builder = builder.with_service_target_resolver(resolver);
85    } else if let Some(k) = api_key {
86        let k = k.to_string();
87        builder = builder.with_auth_resolver(AuthResolver::from_resolver_fn(
88            move |_| -> Result<Option<AuthData>, genai::resolver::Error> {
89                Ok(Some(AuthData::from_single(k.clone())))
90            },
91        ));
92    }
93
94    Ok(builder.build())
95}
96
/// Normalize a user-supplied openai_compat base URL so it ends with `/v1/`.
/// Accepts inputs missing the trailing slash, missing the `/v1/` segment,
/// carrying repeated trailing slashes, or already-correct. Idempotent.
///
/// Surrounding whitespace is removed first. Only the tail of the string is
/// inspected; the URL is not otherwise parsed or validated.
///
/// Examples:
/// - `http://localhost:1234`        → `http://localhost:1234/v1/`
/// - `http://localhost:1234/`       → `http://localhost:1234/v1/`
/// - `http://localhost:1234//`      → `http://localhost:1234/v1/`
/// - `http://localhost:1234/v1`     → `http://localhost:1234/v1/`
/// - `http://localhost:1234/v1//`   → `http://localhost:1234/v1/`
/// - `http://localhost:1234/v1/`    → unchanged
/// - `https://api.example.com/custom/v1/` → unchanged
/// - `https://api.example.com/custom/`    → `https://api.example.com/custom/v1/`
fn normalize_openai_compat_base_url(base: &str) -> String {
    // Drop every trailing slash so `…/v1`, `…/v1/` and `…/v1//` all reduce to
    // the same stem before the suffix check; exactly one slash is re-added.
    let stem = base.trim().trim_end_matches('/');
    if stem.ends_with("/v1") {
        format!("{stem}/")
    } else {
        format!("{stem}/v1/")
    }
}
120
121/// Return the model name string to pass to `exec_chat`.
122///
123/// For all currently-supported providers the model name is passed through
124/// verbatim.  This shim exists so that `CloudCaptioner` and `CloudBackend`
125/// have a single stable call site to update if a provider ever requires a
126/// prefix (e.g. `"models/gemini-pro-vision"`).
127pub fn resolve_request_model(_provider: ProviderKind, model: &str) -> String {
128    model.to_string()
129}
130
#[cfg(test)]
mod provider_tests {
    use super::*;

    /// Each supported config string must resolve to its own variant.
    /// Checking only `is_ok()` would not notice two match arms being swapped.
    #[test]
    fn parses_every_supported_provider() {
        let expected = [
            ("openai", ProviderKind::OpenAi),
            ("anthropic", ProviderKind::Anthropic),
            ("gemini", ProviderKind::Gemini),
            ("xai", ProviderKind::XAi),
            ("groq", ProviderKind::Groq),
            ("deepseek", ProviderKind::DeepSeek),
            ("together", ProviderKind::Together),
            ("fireworks", ProviderKind::Fireworks),
            ("openai_compat", ProviderKind::OpenAiCompat),
        ];
        for (input, kind) in expected {
            assert_eq!(
                ProviderKind::parse(input),
                Ok(kind),
                "wrong result for {input}"
            );
        }
    }

    /// Unknown identifiers are rejected and the message names the input.
    #[test]
    fn rejects_unknown_provider() {
        assert_eq!(
            ProviderKind::parse("bogus"),
            Err("unknown provider: bogus".to_string())
        );
        assert!(ProviderKind::parse("").is_err());
    }
}
157
/// Cloud backend. Builds a `genai::Client` once at construction; the
/// service holds an `Arc<dyn SummarizerBackend>` so this struct is
/// cheap to clone.
#[derive(Debug, Clone)]
pub struct CloudBackend {
    /// Config-key name of this backend; returned by `name()` and attached
    /// to the warning logged by `compact`.
    name: String,
    /// Model id passed to `exec_chat` on every request; returned by
    /// `model_id()`.
    model: String,
    /// Client produced by `build_client` in `CloudBackend::new`.
    client: Client,
}
167
168impl CloudBackend {
169    /// Build a cloud backend.
170    ///
171    /// * `name` — config-key name (e.g. "fast").
172    /// * `provider` — parsed provider kind.
173    /// * `model` — the literal model id passed to genai (e.g. "gpt-4o-mini").
174    /// * `base_url` — only used when `provider == OpenAiCompat`. For native
175    ///   providers, pass `None`.
176    /// * `api_key` — when `Some`, installs an explicit auth override. When
177    ///   `None`, genai's default env-var resolution applies (OPENAI_API_KEY,
178    ///   ANTHROPIC_API_KEY, etc.).
179    // Consumed in Task 6 (registry) when constructing backends from config.
180    pub fn new(
181        name: impl Into<String>,
182        provider: ProviderKind,
183        model: impl Into<String>,
184        base_url: Option<String>,
185        api_key: Option<String>,
186    ) -> Result<Self, BackendError> {
187        let name = name.into();
188        let model = model.into();
189
190        let client = build_client(provider, base_url.as_deref(), api_key.as_deref())
191            .map_err(BackendError::Invalid)?;
192
193        Ok(Self {
194            name,
195            model,
196            client,
197        })
198    }
199
200    fn build_request(&self, content: &str, opts: &CompactOpts) -> ChatRequest {
201        let parts = render_abstractive(opts, content);
202        ChatRequest::new(vec![
203            ChatMessage::system(parts.system),
204            ChatMessage::user(parts.user),
205        ])
206    }
207
208    /// Translate a genai error into our error type by matching on
209    /// `genai::Error`'s structural variants. HTTP status codes come
210    /// straight out of `webc::Error::ResponseFailedStatus`; genai's own
211    /// request-validation variants map to `Invalid`/`AuthFailed`.
212    fn map_error(err: genai::Error) -> BackendError {
213        use genai::Error::{
214            ChatReqHasNoMessages, LastChatMessageIsNotUser, MessageContentTypeNotSupported,
215            MessageRoleNotSupported, NoAuthData, NoAuthResolver, RequiresApiKey, WebAdapterCall,
216            WebModelCall,
217        };
218        use genai::webc::Error::ResponseFailedStatus;
219
220        match &err {
221            WebModelCall {
222                webc_error: ResponseFailedStatus { status, .. },
223                ..
224            }
225            | WebAdapterCall {
226                webc_error: ResponseFailedStatus { status, .. },
227                ..
228            } => {
229                if status.as_u16() == 429 {
230                    BackendError::RateLimited
231                } else if matches!(status.as_u16(), 401 | 403) {
232                    BackendError::AuthFailed(err.to_string())
233                } else if status.is_client_error() {
234                    BackendError::ModelError(err.to_string())
235                } else {
236                    BackendError::Unavailable(err.to_string())
237                }
238            }
239            RequiresApiKey { .. } | NoAuthResolver { .. } | NoAuthData { .. } => {
240                BackendError::AuthFailed(err.to_string())
241            }
242            ChatReqHasNoMessages { .. }
243            | LastChatMessageIsNotUser { .. }
244            | MessageRoleNotSupported { .. }
245            | MessageContentTypeNotSupported { .. } => BackendError::Invalid(err.to_string()),
246            _ => BackendError::Unavailable(err.to_string()),
247        }
248    }
249}
250
251#[async_trait]
252impl SummarizerBackend for CloudBackend {
253    async fn compact(&self, content: &str, opts: &CompactOpts) -> Result<String, BackendError> {
254        if content.trim().is_empty() {
255            return Err(BackendError::Invalid("empty content".to_string()));
256        }
257        // Only Abstractive uses the cloud round-trip; Extractive and
258        // Headlines belong to the extractive backend. If a caller asks
259        // a cloud backend for Extractive output, we still send the
260        // chat request — the abstractive prompt produces extractive-style
261        // output well enough — but log a warning so this misuse is visible.
262        if opts.mode != CompactMode::Abstractive {
263            tracing::warn!(
264                target: "rover::summarizer",
265                mode = opts.mode.as_str(),
266                backend = self.name,
267                "cloud backend invoked for non-abstractive mode",
268            );
269        }
270        let req = self.build_request(content, opts);
271        let resp = self
272            .client
273            .exec_chat(&self.model, req, None)
274            .await
275            .map_err(Self::map_error)?;
276        Ok(resp.first_text().unwrap_or_default().to_string())
277    }
278
279    fn name(&self) -> &str {
280        &self.name
281    }
282
283    fn model_id(&self) -> &str {
284        &self.model
285    }
286
287    fn uses_model_prompt(&self) -> bool {
288        true
289    }
290}
291
#[cfg(test)]
mod cloud_tests {
    use super::*;
    use crate::summarizer::backend::{CompactMode, PreserveSection, Style};

    /// Baseline abstractive options shared by the tests below.
    fn opts() -> CompactOpts {
        CompactOpts {
            mode: CompactMode::Abstractive,
            style: Style::Prose,
            target_tokens: Some(200),
            focus: None,
            preserve: vec![],
            backend_name: "fast".to_string(),
        }
    }

    /// A rendered request carries exactly a system and a user message.
    #[test]
    fn build_request_has_two_messages() {
        let explicit_key = Some(String::from("noop"));
        let backend =
            CloudBackend::new("fast", ProviderKind::OpenAi, "gpt-4o-mini", None, explicit_key)
                .unwrap();

        let request = backend.build_request("hello", &opts());

        assert_eq!(request.messages.len(), 2);
    }

    /// Without a base URL the openai_compat provider cannot be constructed.
    #[test]
    fn openai_compat_requires_base_url() {
        let outcome = CloudBackend::new("custom", ProviderKind::OpenAiCompat, "m", None, None);
        assert!(matches!(outcome, Err(BackendError::Invalid(_))));
    }

    /// With a base URL (and key) construction succeeds.
    #[test]
    fn openai_compat_constructs_with_base_url() {
        let base_url = Some(String::from("http://127.0.0.1:1234/v1"));
        let api_key = Some(String::from("k"));
        let outcome =
            CloudBackend::new("custom", ProviderKind::OpenAiCompat, "m", base_url, api_key);
        assert!(outcome.is_ok());
    }

    /// Only constructs a `PreserveSection` value; nothing is asserted.
    #[test]
    fn preserve_optional_field_round_trips() {
        let _ = vec![PreserveSection::Code];
    }
}
346
#[cfg(test)]
mod normalize_tests {
    use super::normalize_openai_compat_base_url;

    /// Assert that `input` normalizes to exactly `expected`.
    fn assert_normalizes(input: &str, expected: &str) {
        assert_eq!(normalize_openai_compat_base_url(input), expected);
    }

    /// Bare host, host with slash, and `/v1` without slash all gain `/v1/`.
    #[test]
    fn appends_v1_slash_when_missing() {
        let canonical = "http://localhost:1234/v1/";
        for input in [
            "http://localhost:1234",
            "http://localhost:1234/",
            "http://localhost:1234/v1",
        ] {
            assert_normalizes(input, canonical);
        }
    }

    /// An already-normalized URL comes back unchanged.
    #[test]
    fn idempotent_on_already_normalized() {
        assert_normalizes("http://localhost:1234/v1/", "http://localhost:1234/v1/");
    }

    /// A custom path that already ends in `/v1/` is left untouched.
    #[test]
    fn leaves_custom_paths_with_v1_alone() {
        let url = "https://api.example.com/custom/v1/";
        assert_normalizes(url, url);
    }

    /// A custom path without the version segment gets `v1/` appended.
    #[test]
    fn appends_v1_to_custom_paths_without_v1() {
        assert_normalizes(
            "https://api.example.com/custom/",
            "https://api.example.com/custom/v1/",
        );
    }

    /// Leading and trailing whitespace is stripped before normalizing.
    #[test]
    fn trims_whitespace() {
        assert_normalizes("  http://localhost:1234  ", "http://localhost:1234/v1/");
    }
}