use async_trait::async_trait;
use genai::chat::{ChatMessage, ChatRequest};
use genai::resolver::{AuthData, AuthResolver, Endpoint, ServiceTargetResolver};
use genai::{Client, ServiceTarget};
use crate::summarizer::backend::{CompactMode, CompactOpts, SummarizerBackend};
use crate::summarizer::error::BackendError;
use crate::summarizer::prompts::render_abstractive;
/// Provider selector for the cloud summarizer backend.
///
/// Each variant corresponds to exactly one lowercase identifier accepted by
/// [`ProviderKind::parse`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProviderKind {
    /// Parsed from `"openai"`.
    OpenAi,
    /// Parsed from `"anthropic"`.
    Anthropic,
    /// Parsed from `"gemini"`.
    Gemini,
    /// Parsed from `"xai"`.
    XAi,
    /// Parsed from `"groq"`.
    Groq,
    /// Parsed from `"deepseek"`.
    DeepSeek,
    /// Parsed from `"together"`.
    Together,
    /// Parsed from `"fireworks"`.
    Fireworks,
    /// Parsed from `"openai_compat"`.
    OpenAiCompat,
}

impl ProviderKind {
    /// Converts a provider identifier into its [`ProviderKind`].
    ///
    /// The comparison is exact: no trimming and no case folding is applied.
    ///
    /// # Errors
    ///
    /// Returns `Err("unknown provider: <s>")` when `s` is not one of the known identifiers.
    pub fn parse(s: &str) -> Result<Self, String> {
        // Identifier -> variant lookup table; order is irrelevant because keys are unique.
        const KNOWN: [(&str, ProviderKind); 9] = [
            ("openai", ProviderKind::OpenAi),
            ("anthropic", ProviderKind::Anthropic),
            ("gemini", ProviderKind::Gemini),
            ("xai", ProviderKind::XAi),
            ("groq", ProviderKind::Groq),
            ("deepseek", ProviderKind::DeepSeek),
            ("together", ProviderKind::Together),
            ("fireworks", ProviderKind::Fireworks),
            ("openai_compat", ProviderKind::OpenAiCompat),
        ];

        KNOWN
            .iter()
            .find(|(label, _)| *label == s)
            .map(|(_, kind)| kind.clone())
            .ok_or_else(|| format!("unknown provider: {s}"))
    }
}
/// Builds a `genai` [`Client`] configured for `provider`.
///
/// * [`ProviderKind::OpenAiCompat`]: `base_url` is mandatory. It is normalized with
///   `normalize_openai_compat_base_url` and installed through a service-target resolver
///   that also forces the OpenAI adapter and supplies the API key (`"noop"` when `api_key`
///   is `None`).
/// * Any other provider: `base_url` is not consulted. When `api_key` is `Some`, an auth
///   resolver that always yields that key is installed; otherwise the builder is left
///   untouched.
///
/// # Errors
///
/// Returns `Err("openai_compat requires base_url")` when `provider` is
/// [`ProviderKind::OpenAiCompat`] and `base_url` is `None`.
pub fn build_client(
    provider: ProviderKind,
    base_url: Option<&str>,
    api_key: Option<&str>,
) -> Result<Client, String> {
    let builder = Client::builder();

    // Natively supported providers: at most an auth override is needed.
    if provider != ProviderKind::OpenAiCompat {
        let builder = match api_key.map(str::to_owned) {
            Some(key) => builder.with_auth_resolver(AuthResolver::from_resolver_fn(
                move |_| -> Result<Option<AuthData>, genai::resolver::Error> {
                    Ok(Some(AuthData::from_single(key.clone())))
                },
            )),
            None => builder,
        };
        return Ok(builder.build());
    }

    // OpenAI-compatible endpoint: endpoint, auth and adapter are all overridden per request.
    let raw_base = match base_url {
        Some(url) => url,
        None => return Err("openai_compat requires base_url".to_string()),
    };
    let endpoint_url = normalize_openai_compat_base_url(raw_base);
    let auth_key = api_key.unwrap_or("noop").to_string();

    let resolver = ServiceTargetResolver::from_resolver_fn(
        move |service_target: ServiceTarget| -> Result<ServiceTarget, genai::resolver::Error> {
            // Keep the requested model but route it through the OpenAI adapter.
            let ServiceTarget { mut model, .. } = service_target;
            model.adapter_kind = genai::adapter::AdapterKind::OpenAI;
            Ok(ServiceTarget {
                endpoint: Endpoint::from_owned(endpoint_url.clone()),
                auth: AuthData::from_single(auth_key.clone()),
                model,
            })
        },
    );

    Ok(builder.with_service_target_resolver(resolver).build())
}
/// Normalizes an OpenAI-compatible base URL so that it ends in exactly one `/v1/`.
///
/// Surrounding whitespace is trimmed and any run of trailing slashes is collapsed before
/// the suffix check, so `http://host`, `http://host/`, `http://host/v1`, `http://host/v1/`
/// and `http://host/v1//` all normalize to `http://host/v1/`. A base that does not already
/// end in a `/v1` segment gets one appended (`https://host/custom/` becomes
/// `https://host/custom/v1/`).
fn normalize_openai_compat_base_url(base: &str) -> String {
    // Stripping every trailing slash first prevents results such as `/v1//v1/` or `//v1/`
    // when the configured URL carries redundant slashes.
    let root = base.trim().trim_end_matches('/');
    if root.ends_with("/v1") {
        format!("{root}/")
    } else {
        format!("{root}/v1/")
    }
}
/// Returns the model identifier to use in requests.
///
/// The provider argument is currently ignored: `model` is copied into an owned `String`
/// unchanged.
pub fn resolve_request_model(_provider: ProviderKind, model: &str) -> String {
    String::from(model)
}
#[cfg(test)]
mod provider_tests {
    use super::*;

    /// Every supported identifier must map to its own variant, not merely parse successfully;
    /// an `is_ok()` check alone would not notice two identifiers being swapped.
    #[test]
    fn parses_every_supported_provider() {
        let cases = [
            ("openai", ProviderKind::OpenAi),
            ("anthropic", ProviderKind::Anthropic),
            ("gemini", ProviderKind::Gemini),
            ("xai", ProviderKind::XAi),
            ("groq", ProviderKind::Groq),
            ("deepseek", ProviderKind::DeepSeek),
            ("together", ProviderKind::Together),
            ("fireworks", ProviderKind::Fireworks),
            ("openai_compat", ProviderKind::OpenAiCompat),
        ];
        for (input, expected) in cases {
            assert_eq!(
                ProviderKind::parse(input),
                Ok(expected),
                "unexpected result for {input}"
            );
        }
    }

    /// Unknown identifiers are rejected with a message naming the offending input.
    #[test]
    fn rejects_unknown_provider() {
        assert_eq!(
            ProviderKind::parse("bogus"),
            Err("unknown provider: bogus".to_string())
        );
    }
}
/// Summarizer backend that issues chat requests through a `genai` [`Client`].
#[derive(Debug, Clone)]
pub struct CloudBackend {
/// Backend label; returned by `name()` and attached to the non-abstractive-mode warning.
name: String,
/// Model identifier; passed to `exec_chat` and returned by `model_id()`.
model: String,
/// Client produced by `build_client` when the backend is constructed.
client: Client,
}
impl CloudBackend {
    /// Creates a backend named `name` that talks to `model` on `provider`.
    ///
    /// `base_url` and `api_key` are forwarded to [`build_client`].
    ///
    /// # Errors
    ///
    /// Returns [`BackendError::Invalid`] carrying the message from [`build_client`] when the
    /// client cannot be configured.
    pub fn new(
        name: impl Into<String>,
        provider: ProviderKind,
        model: impl Into<String>,
        base_url: Option<String>,
        api_key: Option<String>,
    ) -> Result<Self, BackendError> {
        let (name, model) = (name.into(), model.into());
        build_client(provider, base_url.as_deref(), api_key.as_deref())
            .map(|client| Self {
                name,
                model,
                client,
            })
            .map_err(BackendError::Invalid)
    }

    /// Renders the abstractive prompt for `content` and wraps it in a two-message request:
    /// the system part first, then the user part.
    fn build_request(&self, content: &str, opts: &CompactOpts) -> ChatRequest {
        let rendered = render_abstractive(opts, content);
        let messages = vec![
            ChatMessage::system(rendered.system),
            ChatMessage::user(rendered.user),
        ];
        ChatRequest::new(messages)
    }

    /// Translates a `genai` error into the backend's own error type.
    ///
    /// * HTTP 429 maps to [`BackendError::RateLimited`].
    /// * HTTP 401 / 403 and missing-credential errors map to [`BackendError::AuthFailed`].
    /// * Any other 4xx status maps to [`BackendError::ModelError`].
    /// * Malformed chat requests map to [`BackendError::Invalid`].
    /// * Everything else maps to [`BackendError::Unavailable`].
    fn map_error(err: genai::Error) -> BackendError {
        use genai::webc::Error::ResponseFailedStatus;
        use genai::Error::{
            ChatReqHasNoMessages, LastChatMessageIsNotUser, MessageContentTypeNotSupported,
            MessageRoleNotSupported, NoAuthData, NoAuthResolver, RequiresApiKey, WebAdapterCall,
            WebModelCall,
        };

        // Rendered lazily so the rate-limit branch does not build a string it never uses.
        let detail = || err.to_string();

        match &err {
            WebModelCall {
                webc_error: ResponseFailedStatus { status, .. },
                ..
            }
            | WebAdapterCall {
                webc_error: ResponseFailedStatus { status, .. },
                ..
            } => match status.as_u16() {
                429 => BackendError::RateLimited,
                401 | 403 => BackendError::AuthFailed(detail()),
                _ if status.is_client_error() => BackendError::ModelError(detail()),
                _ => BackendError::Unavailable(detail()),
            },
            RequiresApiKey { .. } | NoAuthResolver { .. } | NoAuthData { .. } => {
                BackendError::AuthFailed(detail())
            }
            ChatReqHasNoMessages { .. }
            | LastChatMessageIsNotUser { .. }
            | MessageRoleNotSupported { .. }
            | MessageContentTypeNotSupported { .. } => BackendError::Invalid(detail()),
            _ => BackendError::Unavailable(detail()),
        }
    }
}
#[async_trait]
impl SummarizerBackend for CloudBackend {
    /// Sends `content` to the configured model and returns the first text of the response.
    ///
    /// A warning is logged (the call still proceeds) when `opts.mode` is not
    /// [`CompactMode::Abstractive`]. A response without text yields an empty string.
    ///
    /// # Errors
    ///
    /// * [`BackendError::Invalid`] with `"empty content"` when `content` is blank.
    /// * Whatever `map_error` produces when the chat call fails.
    async fn compact(&self, content: &str, opts: &CompactOpts) -> Result<String, BackendError> {
        if content.trim().is_empty() {
            return Err(BackendError::Invalid(String::from("empty content")));
        }

        let abstractive = opts.mode == CompactMode::Abstractive;
        if !abstractive {
            tracing::warn!(
                target: "rover::summarizer",
                mode = opts.mode.as_str(),
                backend = self.name,
                "cloud backend invoked for non-abstractive mode",
            );
        }

        let request = self.build_request(content, opts);
        let response = self
            .client
            .exec_chat(&self.model, request, None)
            .await
            .map_err(Self::map_error)?;

        let summary = response.first_text().unwrap_or_default();
        Ok(summary.to_string())
    }

    /// Backend label supplied at construction time.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Model identifier supplied at construction time.
    fn model_id(&self) -> &str {
        self.model.as_str()
    }

    /// Always `true` for this backend.
    fn uses_model_prompt(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod cloud_tests {
    use super::*;
    use crate::summarizer::backend::{CompactMode, PreserveSection, Style};

    /// Baseline abstractive options shared by the tests in this module.
    fn opts() -> CompactOpts {
        CompactOpts {
            backend_name: String::from("fast"),
            preserve: Vec::new(),
            focus: None,
            target_tokens: Some(200),
            style: Style::Prose,
            mode: CompactMode::Abstractive,
        }
    }

    /// The request carries exactly two messages.
    #[test]
    fn build_request_has_two_messages() {
        let api_key = Some(String::from("noop"));
        let backend =
            CloudBackend::new("fast", ProviderKind::OpenAi, "gpt-4o-mini", None, api_key).unwrap();
        let request = backend.build_request("hello", &opts());
        assert_eq!(request.messages.len(), 2);
    }

    /// Without a base URL the OpenAI-compatible provider cannot be constructed.
    #[test]
    fn openai_compat_requires_base_url() {
        let outcome = CloudBackend::new("custom", ProviderKind::OpenAiCompat, "m", None, None);
        assert!(matches!(outcome, Err(BackendError::Invalid(_))));
    }

    /// With a base URL (and key) construction succeeds.
    #[test]
    fn openai_compat_constructs_with_base_url() {
        let base_url = Some(String::from("http://127.0.0.1:1234/v1"));
        let api_key = Some(String::from("k"));
        let outcome =
            CloudBackend::new("custom", ProviderKind::OpenAiCompat, "m", base_url, api_key);
        assert!(outcome.is_ok());
    }

    // NOTE(review): this test only constructs a value and asserts nothing; it appears to
    // exist to keep the `PreserveSection` import in use. Confirm the intent.
    #[test]
    fn preserve_optional_field_round_trips() {
        let _ = vec![PreserveSection::Code];
    }
}
#[cfg(test)]
mod normalize_tests {
    use super::normalize_openai_compat_base_url;

    /// Bases lacking the `/v1/` suffix (fully or only the final slash) all gain it.
    #[test]
    fn appends_v1_slash_when_missing() {
        let expected = "http://localhost:1234/v1/";
        for input in [
            "http://localhost:1234",
            "http://localhost:1234/",
            "http://localhost:1234/v1",
        ] {
            assert_eq!(normalize_openai_compat_base_url(input), expected);
        }
    }

    /// An already-normalized base is returned unchanged.
    #[test]
    fn idempotent_on_already_normalized() {
        let normalized = "http://localhost:1234/v1/";
        let output = normalize_openai_compat_base_url(normalized);
        assert_eq!(output, normalized);
    }

    /// A custom path that already ends in `/v1/` is kept verbatim.
    #[test]
    fn leaves_custom_paths_with_v1_alone() {
        let input = "https://api.example.com/custom/v1/";
        assert_eq!(normalize_openai_compat_base_url(input), input);
    }

    /// A custom path without the `v1` segment gets it appended.
    #[test]
    fn appends_v1_to_custom_paths_without_v1() {
        let output = normalize_openai_compat_base_url("https://api.example.com/custom/");
        assert_eq!(output, "https://api.example.com/custom/v1/");
    }

    /// Leading and trailing whitespace is ignored.
    #[test]
    fn trims_whitespace() {
        let output = normalize_openai_compat_base_url(" http://localhost:1234 ");
        assert_eq!(output, "http://localhost:1234/v1/");
    }
}