use http::status::StatusCode;
use rig::{
agent::{Agent, PromptHook},
client::{CompletionClient, ModelListingClient, ProviderClient, ProviderClientError},
completion::{CompletionError, CompletionModel, Prompt, PromptError},
http_client::Error as HttpError,
providers::{anthropic, cohere, gemini, ollama, openai, openrouter},
};
use serde::{Deserialize, Serialize};
use strum::EnumIter;
use strum_macros::{Display, EnumString};
use thiserror::Error;
/// AI backends that can be selected for commit-message generation.
///
/// Parsing (`FromStr`), `Display`, and serde (de)serialisation all use the
/// lowercase variant name, e.g. `"openrouter"` for [`Provider::OpenRouter`].
#[derive(Debug, Clone, PartialEq)]
#[derive(Deserialize, Serialize)]
#[derive(EnumString, Display, EnumIter)]
#[serde(rename_all = "lowercase")]
#[strum(serialize_all = "lowercase")]
pub enum Provider {
    /// Backed by `rig`'s `openai` client.
    OpenAI,
    /// Backed by `rig`'s `gemini` client.
    Gemini,
    /// Backed by `rig`'s `anthropic` client.
    Anthropic,
    /// Backed by `rig`'s `cohere` client (no model listing support here).
    Cohere,
    /// Backed by `rig`'s `ollama` client.
    Ollama,
    /// Backed by `rig`'s `openrouter` client.
    OpenRouter,
    /// In-process mock that returns canned responses without any provider client.
    MockAi,
}
#[derive(Debug, Error)]
pub enum AiError {
#[error("Rate limit exceeded {0}")]
RateExceeded(String),
#[error("Provider does not provide model")]
NotFound,
#[error("Provider client error: {0}")]
ProviderError(String),
#[error("Unknown error: {0}")]
Other(String),
}
/// One model offered by a provider, ready to show to a user.
#[derive(Debug, PartialEq)]
pub struct ModelEntry {
    /// Human-readable label; the listing impls in this module format it as
    /// `"<name> (<id>)"`.
    pub display: String,
    /// Provider-side model identifier.
    pub id: String,
}
/// Something that can turn a staged diff into a commit message.
#[async_trait::async_trait]
pub trait GenerateCommitMsg {
    /// Produces a commit message for `diff`.
    ///
    /// # Errors
    ///
    /// Returns an [`AiError`] describing why the backend could not answer.
    async fn generate_commit_msg(&self, diff: &str) -> Result<String, AiError>;
}
/// Something that can enumerate the models a provider offers.
#[async_trait::async_trait]
pub trait ListModels: Send + Sync {
    /// Returns every available model as a displayable [`ModelEntry`].
    async fn list_models(&self) -> anyhow::Result<Vec<ModelEntry>>;
}
#[async_trait::async_trait]
impl<M, P> GenerateCommitMsg for Agent<M, P>
where
M: CompletionModel + 'static,
P: PromptHook<M> + 'static,
{
async fn generate_commit_msg(&self, diff: &str) -> Result<String, AiError> {
Ok(self.prompt(diff).await?)
}
}
/// Every `rig` client that can list models gets [`ListModels`] for free.
#[async_trait::async_trait]
impl<T> ListModels for T
where
    T: ModelListingClient + Send + Sync,
{
    /// Fetches the provider's models and labels each as `"<display name> (<id>)"`.
    async fn list_models(&self) -> anyhow::Result<Vec<ModelEntry>> {
        // Fully qualified to pick rig's method rather than this trait's own
        // `list_models`, which has the same name.
        let listed = ModelListingClient::list_models(self).await?;
        let entries: Vec<ModelEntry> = listed
            .into_iter()
            .map(|model| {
                let display = format!("{} ({})", model.display_name(), model.id);
                let id = model.id.to_string();
                ModelEntry { display, id }
            })
            .collect();
        Ok(entries)
    }
}
impl From<PromptError> for AiError {
fn from(error: PromptError) -> Self {
match error {
PromptError::CompletionError(e) => match e {
CompletionError::HttpError(e) => match e {
HttpError::InvalidStatusCode(s) => match s {
StatusCode::TOO_MANY_REQUESTS => AiError::RateExceeded(e.to_string()),
StatusCode::NOT_FOUND => AiError::NotFound,
_ => {
dbg!(&s, s.as_u16());
AiError::Other(e.to_string())
}
},
HttpError::InvalidStatusCodeWithMessage(code, msg) => match code.as_u16() {
429 => AiError::RateExceeded(msg.to_string()),
404 => AiError::NotFound,
_ => AiError::Other(format!("{code}: {msg}")),
},
_ => {
dbg!(&e);
AiError::Other(e.to_string())
}
},
_ => AiError::Other(e.to_string()),
},
_ => AiError::Other(error.to_string()),
}
}
}
impl From<ProviderClientError> for AiError {
fn from(e: ProviderClientError) -> Self {
AiError::ProviderError(e.to_string())
}
}
/// Builds the backend that turns a staged diff into a commit message.
///
/// `model` is the provider-specific model id. `system_message` becomes the
/// agent's preamble; when it is `None`, an empty preamble is used instead of
/// panicking. [`Provider::MockAi`] ignores both and returns [`MockAi::default`].
///
/// # Errors
///
/// Returns [`AiError::ProviderError`] when the provider client cannot be
/// created from the environment.
pub fn build_commit_agent(
    provider: Provider,
    model: String,
    system_message: Option<&str>,
) -> Result<Box<dyn GenerateCommitMsg>, AiError> {
    // `system_message` is optional in the signature, so a missing value must
    // not abort the process. Fall back to an empty preamble.
    let preamble = system_message.unwrap_or_default();
    let agent: Box<dyn GenerateCommitMsg> = match provider {
        Provider::OpenAI => Box::new(
            openai::Client::from_env()?
                .agent(&model)
                .preamble(preamble)
                .build(),
        ),
        Provider::Gemini => Box::new(
            gemini::Client::from_env()?
                .agent(&model)
                .preamble(preamble)
                .build(),
        ),
        Provider::Anthropic => Box::new(
            anthropic::Client::from_env()?
                .agent(&model)
                .preamble(preamble)
                .build(),
        ),
        Provider::Cohere => Box::new(
            cohere::Client::from_env()?
                .agent(&model)
                .preamble(preamble)
                .build(),
        ),
        Provider::Ollama => Box::new(
            ollama::Client::from_env()?
                .agent(&model)
                .preamble(preamble)
                .build(),
        ),
        Provider::OpenRouter => Box::new(
            openrouter::Client::from_env()?
                .agent(&model)
                .preamble(preamble)
                .build(),
        ),
        Provider::MockAi => Box::new(MockAi::default()),
    };
    Ok(agent)
}
/// Creates a client that can enumerate the models offered by `provider`.
///
/// # Errors
///
/// Returns [`AiError::ProviderError`] if the client cannot be built from the
/// environment, or [`AiError::Other`] for [`Provider::Cohere`], which has no
/// model-listing support here.
pub fn build_model_listing_client(provider: Provider) -> Result<Box<dyn ListModels>, AiError> {
    let lister: Box<dyn ListModels> = match provider {
        Provider::Cohere => {
            return Err(AiError::Other(String::from(
                "Cohere does not support model listing",
            )));
        }
        Provider::MockAi => Box::new(MockAi::default()),
        Provider::OpenAI => Box::new(openai::Client::from_env()?),
        Provider::Gemini => Box::new(gemini::Client::from_env()?),
        Provider::Anthropic => Box::new(anthropic::Client::from_env()?),
        Provider::Ollama => Box::new(ollama::Client::from_env()?),
        Provider::OpenRouter => Box::new(openrouter::Client::from_env()?),
    };
    Ok(lister)
}
/// Commit message returned by a default-constructed [`MockAi`].
pub const MOCK_RESPONSE: &str = "feat: add file test.txt";

/// Deterministic stand-in backend: its trait impls return fixed values and
/// perform no I/O.
pub struct MockAi {
    /// The message returned for every diff.
    pub response: String,
}

impl Default for MockAi {
    /// Creates a mock that always answers with [`MOCK_RESPONSE`].
    fn default() -> Self {
        let response = String::from(MOCK_RESPONSE);
        MockAi { response }
    }
}
#[async_trait::async_trait]
impl GenerateCommitMsg for MockAi {
    /// Ignores the diff and returns the configured canned response.
    async fn generate_commit_msg(&self, _diff: &str) -> Result<String, AiError> {
        let canned = self.response.to_owned();
        Ok(canned)
    }
}
#[async_trait::async_trait]
impl ListModels for MockAi {
    /// Reports exactly one fixed model, `mock-1`.
    async fn list_models(&self) -> anyhow::Result<Vec<ModelEntry>> {
        let only_model = ModelEntry {
            display: String::from("mock-model (mock-1)"),
            id: String::from("mock-1"),
        };
        Ok(vec![only_model])
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::git::stage_files;
    use crate::test_utils::setup;
    use anyhow::Result;
    use rstest::rstest;

    #[rstest]
    #[case("openai", Provider::OpenAI)]
    #[case("gemini", Provider::Gemini)]
    #[case("anthropic", Provider::Anthropic)]
    #[case("cohere", Provider::Cohere)]
    #[case("ollama", Provider::Ollama)]
    #[case("openrouter", Provider::OpenRouter)]
    #[case("mockai", Provider::MockAi)]
    fn provider_from_string(#[case] input: &str, #[case] expected: Provider) {
        assert_eq!(input.parse::<Provider>().unwrap(), expected);
        // Display and FromStr share the lowercase spelling, so they must round-trip.
        assert_eq!(expected.to_string(), input);
    }

    #[tokio::test]
    async fn test_commit_msg_gen() -> Result<()> {
        let (repository, _dir) = setup()?;
        stage_files(&["test.txt".to_string()], &repository)?;
        let diff = crate::git::get_diff(&repository)?.expect("diff should exist");
        let agent = MockAi::default();
        let msg = agent.generate_commit_msg(&diff).await?;
        assert_eq!(msg, MOCK_RESPONSE);
        Ok(())
    }

    #[tokio::test]
    async fn build_commit_agent_returns_mock() -> Result<()> {
        let agent =
            build_commit_agent(Provider::MockAi, "mock-1".to_string(), Some("system"))?;
        assert_eq!(agent.generate_commit_msg("diff").await?, MOCK_RESPONSE);
        Ok(())
    }

    #[tokio::test]
    async fn test_model_listing() -> Result<()> {
        let agent = MockAi::default();
        let models = agent.list_models().await?;
        assert_eq!(
            models,
            vec![ModelEntry {
                display: "mock-model (mock-1)".to_string(),
                id: "mock-1".to_string(),
            }]
        );
        Ok(())
    }

    #[rstest]
    #[case(Provider::Gemini, "GEMINI_API_KEY")]
    #[case(Provider::OpenAI, "OPENAI_API_KEY")]
    #[case(Provider::Anthropic, "ANTHROPIC_API_KEY")]
    #[case(Provider::OpenRouter, "OPENROUTER_API_KEY")]
    fn build_model_listing_client_works(#[case] provider: Provider, #[case] env_key: &str) {
        // NOTE(review): `set_var` is unsafe because it races with concurrent
        // environment reads, and the test harness runs these cases on parallel
        // threads. Each case only writes its own key with a dummy value. That is
        // tolerated here, but it is not a full soundness guarantee. Consider
        // serialising env-mutating tests.
        unsafe { std::env::set_var(env_key, "fake-key") };
        if let Err(err) = build_model_listing_client(provider) {
            panic!("expected a listing client once {env_key} is set, got: {err}");
        }
    }

    #[test]
    fn build_model_listing_client_rejects_cohere() {
        let result = build_model_listing_client(Provider::Cohere);
        assert!(matches!(result, Err(AiError::Other(_))));
    }
}