use std::{env, fmt::Debug, time::Duration};
use color_eyre::eyre::{Context, ContextCompat, eyre};
use reqwest::{
Client, ClientBuilder, RequestBuilder, Response, StatusCode,
header::{self, HeaderMap, HeaderName, HeaderValue},
};
use schemars::{JsonSchema, Schema, schema_for};
use serde::{Deserialize, de::DeserializeOwned};
use tokio_util::sync::CancellationToken;
use tracing::instrument;
use crate::{
config::AiModelConfig,
errors::{AppError, Result, UserFacingError},
};
mod anthropic;
mod gemini;
mod ollama;
mod openai;
/// Provider-specific behavior shared by every AI backend implementation.
pub trait AiProviderBase: Send + Sync {
    /// Human-readable provider name, used in logs and error messages
    fn provider_name(&self) -> &'static str;
    /// Builds the authentication header (name and value) for the given API key
    fn auth_header(&self, api_key: String) -> (HeaderName, String);
    /// Name of the environment variable expected to hold this provider's API key
    fn api_key_env_var_name(&self) -> &str;
    /// Builds the HTTP request for a structured-output generation call,
    /// embedding both prompts and the JSON schema the response must follow
    fn build_request(
        &self,
        client: &Client,
        sys_prompt: &str,
        user_prompt: &str,
        json_schema: &Schema,
    ) -> RequestBuilder;
}
/// Full provider interface: adds async response parsing on top of [`AiProviderBase`].
/// `trait_variant::make(Send)` generates a variant whose futures are `Send`.
#[trait_variant::make(Send)]
pub trait AiProvider: AiProviderBase {
    /// Parses a successful HTTP response body into the expected output type `T`
    async fn parse_response<T>(&self, res: Response) -> Result<T>
    where
        Self: Sized,
        T: DeserializeOwned + JsonSchema + Debug;
}
/// Client for AI content generation, holding a primary model configuration and
/// an optional fallback that is retried when the primary is rate-limited or unavailable.
#[cfg_attr(test, derive(Debug))]
pub struct AiClient<'a> {
    // Shared HTTP client carrying the timeouts and default headers set in `new`
    inner: Client,
    // Display alias of the primary model, used only in log messages
    primary_alias: &'a str,
    // Configuration of the model tried first for every request
    primary: &'a AiModelConfig,
    // Display alias of the fallback model, used only in log messages
    fallback_alias: &'a str,
    // Optional model retried after certain primary-model failures
    fallback: Option<&'a AiModelConfig>,
}
impl<'a> AiClient<'a> {
    /// Creates a new [`AiClient`] for the given primary model and optional fallback.
    ///
    /// The fallback, when present, is only used after the primary model reports
    /// a rate limit or unavailability (see `generate_content`).
    ///
    /// # Errors
    /// Returns an error if the underlying HTTP client cannot be built.
    pub fn new(
        primary_alias: &'a str,
        primary: &'a AiModelConfig,
        fallback_alias: &'a str,
        fallback: Option<&'a AiModelConfig>,
    ) -> Result<Self> {
        // Every supported provider speaks JSON, so set the content type once
        // as a default header for all requests
        let mut headers = HeaderMap::new();
        headers.append(header::CONTENT_TYPE, HeaderValue::from_static("application/json"));
        let inner = ClientBuilder::new()
            .connect_timeout(Duration::from_secs(5))
            // Generous overall timeout: model generation can take a while
            .timeout(Duration::from_secs(5 * 60))
            .user_agent("intelli-shell")
            .default_headers(headers)
            .build()
            .wrap_err("Couldn't build AI client")?;
        Ok(AiClient {
            inner,
            primary_alias,
            primary,
            fallback_alias,
            fallback,
        })
    }

    /// Generates a list of command suggestions for the given prompts
    #[instrument(skip_all)]
    pub async fn generate_command_suggestions(
        &self,
        sys_prompt: &str,
        user_prompt: &str,
        cancellation_token: CancellationToken,
    ) -> Result<CommandSuggestions> {
        self.generate_content(sys_prompt, user_prompt, cancellation_token).await
    }

    /// Generates a diagnosis and fixed version of a failing command
    #[instrument(skip_all)]
    pub async fn generate_command_fix(
        &self,
        sys_prompt: &str,
        user_prompt: &str,
        cancellation_token: CancellationToken,
    ) -> Result<CommandFix> {
        self.generate_content(sys_prompt, user_prompt, cancellation_token).await
    }

    /// Generates a completion suggestion for a command variable
    #[instrument(skip_all)]
    pub async fn generate_completion_suggestion(
        &self,
        sys_prompt: &str,
        user_prompt: &str,
        cancellation_token: CancellationToken,
    ) -> Result<VariableCompletionSuggestion> {
        self.generate_content(sys_prompt, user_prompt, cancellation_token).await
    }

    /// Requests structured content of type `T` from the primary model, retrying
    /// once with the fallback model (if configured) when the primary reports a
    /// rate limit or unavailability. Any other error is returned as-is.
    async fn generate_content<T>(
        &self,
        sys_prompt: &str,
        user_prompt: &str,
        cancellation_token: CancellationToken,
    ) -> Result<T>
    where
        T: DeserializeOwned + JsonSchema + Debug,
    {
        let primary_result = self
            .execute_request(self.primary, sys_prompt, user_prompt, cancellation_token.clone())
            .await;
        // Only fall back for errors signaling the primary provider itself can't
        // serve the request right now; other failures (auth, bad request, etc.)
        // would most likely fail on the fallback too or must surface to the user
        if let Err(AppError::UserFacing(err)) = &primary_result
            && let Some(fallback) = self.fallback
        {
            let retry_reason = match err {
                UserFacingError::AiRateLimit => Some("rate-limited"),
                UserFacingError::AiUnavailable => Some("unavailable"),
                _ => None,
            };
            if let Some(reason) = retry_reason {
                tracing::warn!(
                    "Primary model ({}) {reason}, retrying with fallback ({})",
                    self.primary_alias,
                    self.fallback_alias
                );
                return self
                    .execute_request(fallback, sys_prompt, user_prompt, cancellation_token)
                    .await;
            }
        }
        primary_result
    }

    /// Executes a single generation request against the given model config and
    /// parses the response into `T`.
    ///
    /// Maps transport failures and non-success HTTP statuses to user-facing
    /// errors; cancellation via the token short-circuits with `Cancelled`.
    #[instrument(skip_all, fields(provider = config.provider().provider_name()))]
    async fn execute_request<T>(
        &self,
        config: &AiModelConfig,
        sys_prompt: &str,
        user_prompt: &str,
        cancellation_token: CancellationToken,
    ) -> Result<T>
    where
        T: DeserializeOwned + JsonSchema + Debug,
    {
        let provider = config.provider();
        let json_schema = build_json_schema_for::<T>()?;
        let mut req_builder = provider.build_request(&self.inner, sys_prompt, user_prompt, &json_schema);
        // Attach the auth header only when the provider's API key env var is
        // set; some providers may not require a key at all
        let api_key_env = provider.api_key_env_var_name();
        if let Ok(api_key) = env::var(api_key_env) {
            let (header_name, header_value) = provider.auth_header(api_key);
            let mut header_value =
                HeaderValue::from_str(&header_value).wrap_err_with(|| format!("Invalid '{api_key_env}' value"))?;
            // Mark sensitive so the key is not exposed in debug output
            header_value.set_sensitive(true);
            req_builder = req_builder.header(header_name, header_value);
        }
        let req = req_builder.build().wrap_err("Couldn't build api request")?;
        tracing::debug!("Calling {} API: {}", provider.provider_name(), req.url());
        // Race the request against cancellation; `biased` checks the token first
        let res = tokio::select! {
            biased;
            _ = cancellation_token.cancelled() => {
                return Err(UserFacingError::Cancelled.into());
            }
            res = self.inner.execute(req) => {
                res.map_err(|err| {
                    if err.is_timeout() {
                        tracing::error!("Request timeout: {err:?}");
                        UserFacingError::AiRequestTimeout
                    } else if err.is_connect() {
                        tracing::error!("Couldn't connect to the API: {err:?}");
                        UserFacingError::AiRequestFailed(String::from("error connecting to the provider"))
                    } else {
                        tracing::error!("Couldn't perform the request: {err:?}");
                        UserFacingError::AiRequestFailed(err.to_string())
                    }
                })?
            }
        };
        // Translate non-success statuses into specific user-facing errors
        if !res.status().is_success() {
            let status = res.status();
            let status_str = status.as_str();
            let body = res.text().await.unwrap_or_default();
            if status == StatusCode::UNAUTHORIZED || status == StatusCode::FORBIDDEN {
                tracing::warn!(
                    "Got response [{status_str}] {}",
                    status.canonical_reason().unwrap_or_default()
                );
                tracing::debug!("{body}");
                return Err(
                    UserFacingError::AiMissingOrInvalidApiKey(provider.api_key_env_var_name().to_string()).into(),
                );
            } else if status == StatusCode::TOO_MANY_REQUESTS {
                tracing::info!("Got response [{status_str}] Too Many Requests");
                tracing::debug!("{body}");
                return Err(UserFacingError::AiRateLimit.into());
            } else if status == StatusCode::SERVICE_UNAVAILABLE {
                tracing::info!("Got response [{status_str}] Service Unavailable");
                tracing::debug!("{body}");
                return Err(UserFacingError::AiUnavailable.into());
            } else if status == StatusCode::BAD_REQUEST {
                // A bad request points at a bug in how the request is built,
                // so surface it as an internal error with the full body
                tracing::error!("Got response [{status_str}] Bad Request:\n{body}");
                return Err(eyre!("Bad request while fetching {} API:\n{body}", provider.provider_name()).into());
            } else if let Some(reason) = status.canonical_reason() {
                tracing::error!("Got response [{status_str}] {reason}:\n{body}");
                return Err(
                    UserFacingError::AiRequestFailed(format!("received {status_str} {reason} response")).into(),
                );
            } else {
                tracing::error!("Got response [{status_str}]:\n{body}");
                return Err(UserFacingError::AiRequestFailed(format!("received {status_str} response")).into());
            }
        }
        // Delegate body parsing to the concrete provider implementation
        match &config {
            AiModelConfig::Openai(conf) => conf.parse_response(res).await,
            AiModelConfig::Gemini(conf) => conf.parse_response(res).await,
            AiModelConfig::Anthropic(conf) => conf.parse_response(res).await,
            AiModelConfig::Ollama(conf) => conf.parse_response(res).await,
        }
    }
}
/// Structured model output: a list of suggested commands
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandSuggestions {
    /// Suggested commands, each paired with a description
    pub suggestions: Vec<CommandSuggestion>,
}
/// A single suggested command with its description
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandSuggestion {
    /// Short explanation of what the command does
    pub description: String,
    /// The suggested shell command itself
    pub command: String,
}
/// Structured model output: diagnosis and fix for a failing command
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandFix {
    /// Brief summary of the problem
    pub summary: String,
    /// Explanation of why the original command failed
    pub diagnosis: String,
    /// Description of the proposed change
    pub proposal: String,
    /// The corrected command, ready to run
    pub fixed_command: String,
}
/// Structured model output: a completion suggestion for a command variable
#[derive(Debug, Deserialize, JsonSchema)]
pub struct VariableCompletionSuggestion {
    /// The completed command
    pub command: String,
}
/// Derives the JSON schema for `T` and closes every object in it by setting
/// `additionalProperties: false`, as strict structured-output modes require.
fn build_json_schema_for<T: JsonSchema>() -> Result<Schema> {
    let mut schema = schema_for!(T);
    let root = schema.as_object_mut().wrap_err("The type must be an object")?;
    root.insert("additionalProperties".into(), false.into());
    // Nested types referenced by the root live under `$defs`; close those too
    match root.get_mut("$defs") {
        None => {}
        Some(defs) => {
            let defs_map = defs.as_object_mut().wrap_err("Expected objects at $defs")?;
            for definition in defs_map.values_mut() {
                let Some(def_obj) = definition.as_object_mut() else { continue };
                if def_obj.get("type").and_then(|t| t.as_str()) == Some("object") {
                    def_obj.insert("additionalProperties".into(), false.into());
                }
            }
        }
    }
    Ok(schema)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Renders a schema as pretty JSON for manual inspection of the test output
    fn dump(schema: &Schema) {
        println!("{}", serde_json::to_string_pretty(schema).unwrap());
    }

    #[test]
    fn test_command_suggestions_schema() {
        dump(&build_json_schema_for::<CommandSuggestions>().unwrap());
    }

    #[test]
    fn test_command_fix_schema() {
        dump(&build_json_schema_for::<CommandFix>().unwrap());
    }
}