use std::{env, fmt::Debug, time::Duration};

use color_eyre::eyre::{Context, ContextCompat, eyre};
use reqwest::{
    Client, ClientBuilder, RequestBuilder, Response, StatusCode,
    header::{self, HeaderMap, HeaderName, HeaderValue},
};
use schemars::{JsonSchema, Schema, schema_for};
use serde::{Deserialize, de::DeserializeOwned};
use tracing::instrument;

use crate::{
    config::AiModelConfig,
    errors::{AppError, Result, UserFacingError},
};

mod anthropic;
mod gemini;
mod ollama;
mod openai;

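/// Base functionality for an AI provider; kept free of generic methods,
/// likely so it can be used as a trait object (`dyn AiProviderBase`)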
pub trait AiProviderBase: Send + Sync {
    /// Display name of the provider, used in logs and error messages
    fn provider_name(&self) -> &'static str;

    /// Header name and value used to authenticate against the provider's API
    fn auth_header(&self, api_key: String) -> (HeaderName, String);

    /// Name of the environment variable expected to hold the API key
    fn api_key_env_var_name(&self) -> &str;

    /// Builds the provider-specific request for the given prompts and output JSON schema
    fn build_request(
        &self,
        client: &Client,
        sys_prompt: &str,
        user_prompt: &str,
        json_schema: &Schema,
    ) -> RequestBuilder;
}

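/// Full provider behavior, extending [`AiProviderBase`] with generic response parsing;
/// the `trait_variant` macro desugars the async method so that its future is `Send`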
#[trait_variant::make(Send)]
pub trait AiProvider: AiProviderBase {
    /// Parses the provider's API response into the expected output type
    async fn parse_response<T>(&self, res: Response) -> Result<T>
    where
        Self: Sized,
        T: DeserializeOwned + JsonSchema + Debug;
}

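/// Client to generate AI content, holding a primary model and an optional fallback,
/// which is only tried when the primary one is rate-limited.
///
/// ```ignore
/// // Minimal usage sketch; aliases, configs and prompts are illustrative only
/// let client = AiClient::new("primary", &primary_config, "fallback", Some(&fallback_config))?;
/// let suggestions = client
///     .generate_command_suggestions(sys_prompt, "find files larger than 1GB")
///     .await?;
/// ```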
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct AiClient<'a> {
    inner: Client,
    primary_alias: &'a str,
    primary: &'a AiModelConfig,
    fallback_alias: &'a str,
    fallback: Option<&'a AiModelConfig>,
}

impl<'a> AiClient<'a> {
    /// Builds a new AI client for the given primary and fallback model configurations
    pub fn new(
        primary_alias: &'a str,
        primary: &'a AiModelConfig,
        fallback_alias: &'a str,
        fallback: Option<&'a AiModelConfig>,
    ) -> Result<Self> {
        // Default headers sent on every request
        let mut headers = HeaderMap::new();
        headers.append(header::CONTENT_TYPE, HeaderValue::from_static("application/json"));

        // Underlying HTTP client, with a short connect timeout but a generous overall
        // timeout to accommodate slow model responses
        let inner = ClientBuilder::new()
            .connect_timeout(Duration::from_secs(2))
            .timeout(Duration::from_secs(5 * 60))
            .user_agent("intelli-shell")
            .default_headers(headers)
            .build()
            .wrap_err("Couldn't build AI client")?;

        Ok(AiClient {
            inner,
            primary_alias,
            primary,
            fallback_alias,
            fallback,
        })
    }

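    /// Generates command suggestions based on the given prompts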
    #[instrument(skip_all)]
    pub async fn generate_command_suggestions(
        &self,
        sys_prompt: &str,
        user_prompt: &str,
    ) -> Result<CommandSuggestions> {
        self.generate_content(sys_prompt, user_prompt).await
    }

    /// Generates a fix for a failing command based on the given prompts
    #[instrument(skip_all)]
    pub async fn generate_command_fix(&self, sys_prompt: &str, user_prompt: &str) -> Result<CommandFix> {
        self.generate_content(sys_prompt, user_prompt).await
    }

    /// Generates content of the given type, falling back to the secondary model when needed
    async fn generate_content<T>(&self, sys_prompt: &str, user_prompt: &str) -> Result<T>
    where
        T: DeserializeOwned + JsonSchema + Debug,
    {
        // Always try the primary model first
        let primary_result = self.execute_request(self.primary, sys_prompt, user_prompt).await;

        // Only a rate-limit error triggers the fallback; any other error is returned as-is
        if let Err(AppError::UserFacing(UserFacingError::AiRateLimit)) = &primary_result {
            if let Some(fallback) = self.fallback {
                tracing::warn!(
                    "Primary model ({}) rate-limited, retrying with fallback ({})",
                    self.primary_alias,
                    self.fallback_alias
                );
                return self.execute_request(fallback, sys_prompt, user_prompt).await;
            }
        }

        primary_result
    }

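    /// Executes a single request against the given model and parses the response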
    #[instrument(skip_all, fields(provider = config.provider().provider_name()))]
    async fn execute_request<T>(&self, config: &AiModelConfig, sys_prompt: &str, user_prompt: &str) -> Result<T>
    where
        T: DeserializeOwned + JsonSchema + Debug,
    {
        let provider = config.provider();

        // Build the JSON schema the model output must conform to
        let json_schema = build_json_schema_for::<T>()?;

        // Let the provider build its specific request
        let mut req_builder = provider.build_request(&self.inner, sys_prompt, user_prompt, &json_schema);

        // Attach the auth header if the API key env variable is set, marking it
        // sensitive so it's redacted from debug output
        let api_key_env = provider.api_key_env_var_name();
        if let Ok(api_key) = env::var(api_key_env) {
            let (header_name, header_value) = provider.auth_header(api_key);
            let mut header_value =
                HeaderValue::from_str(&header_value).wrap_err_with(|| format!("Invalid '{api_key_env}' value"))?;
            header_value.set_sensitive(true);
            req_builder = req_builder.header(header_name, header_value);
        }

        // Execute the request, mapping transport errors to user-facing ones
        let req = req_builder.build().wrap_err("Couldn't build API request")?;
        tracing::debug!("Calling {} API: {}", provider.provider_name(), req.url());
        let res = self.inner.execute(req).await.map_err(|err| {
            if err.is_timeout() {
                tracing::error!("Request timeout: {err:?}");
                UserFacingError::AiRequestTimeout
            } else if err.is_connect() {
                tracing::error!("Couldn't connect to the API: {err:?}");
                UserFacingError::AiRequestFailed(String::from("error connecting to the provider"))
            } else {
                tracing::error!("Couldn't perform the request: {err:?}");
                UserFacingError::AiRequestFailed(err.to_string())
            }
        })?;

        // Map non-success responses to meaningful errors
        if !res.status().is_success() {
            let status = res.status();
            let status_str = status.as_str();
            let body = res.text().await.unwrap_or_default();
            if status == StatusCode::UNAUTHORIZED || status == StatusCode::FORBIDDEN {
                tracing::warn!(
                    "Got response [{status_str}] {}",
                    status.canonical_reason().unwrap_or_default()
                );
                tracing::debug!("{body}");
                return Err(
                    UserFacingError::AiMissingOrInvalidApiKey(provider.api_key_env_var_name().to_string()).into(),
                );
            } else if status == StatusCode::TOO_MANY_REQUESTS {
                tracing::info!("Got response [{status_str}] Too Many Requests");
                tracing::debug!("{body}");
                return Err(UserFacingError::AiRateLimit.into());
            } else if status == StatusCode::BAD_REQUEST {
                tracing::error!("Got response [{status_str}] Bad Request:\n{body}");
                return Err(eyre!("Bad request while calling the {} API:\n{body}", provider.provider_name()).into());
            } else if let Some(reason) = status.canonical_reason() {
                tracing::error!("Got response [{status_str}] {reason}:\n{body}");
                return Err(
                    UserFacingError::AiRequestFailed(format!("received {status_str} {reason} response")).into(),
                );
            } else {
                tracing::error!("Got response [{status_str}]:\n{body}");
                return Err(UserFacingError::AiRequestFailed(format!("received {status_str} response")).into());
            }
        }

        // Delegate response parsing to the concrete provider, since `parse_response`
        // is generic and can't be called through a trait object
        match config {
            AiModelConfig::Openai(conf) => conf.parse_response(res).await,
            AiModelConfig::Gemini(conf) => conf.parse_response(res).await,
            AiModelConfig::Anthropic(conf) => conf.parse_response(res).await,
            AiModelConfig::Ollama(conf) => conf.parse_response(res).await,
        }
    }
}

/// Structured response containing a list of command suggestions
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandSuggestions {
    /// The list of suggested commands
    pub suggestions: Vec<CommandSuggestion>,
}

/// A single command suggestion
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandSuggestion {
    /// A brief description of what the command does
    pub description: String,
    /// The suggested command
    pub command: String,
}

/// Structured response describing a fix for a failing command
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandFix {
    /// A brief summary of what went wrong
    pub summary: String,
    /// The diagnosis of the root cause of the failure
    pub diagnosis: String,
    /// The proposed fix for the command
    pub proposal: String,
    /// The corrected command, ready to be executed
    pub fixed_command: String,
}

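/// Builds the JSON schema for a type, with `additionalProperties: false` on every object,
/// as providers' structured-output modes typically require closed objects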
fn build_json_schema_for<T: JsonSchema>() -> Result<Schema> {
    let mut schema = schema_for!(T);

    // Disallow additional properties on the root object
    let root = schema.as_object_mut().wrap_err("The type must be an object")?;
    root.insert("additionalProperties".into(), false.into());

    // Disallow additional properties on any nested object definitions as well
    if let Some(defs) = root.get_mut("$defs") {
        for definition in defs.as_object_mut().wrap_err("Expected objects at $defs")?.values_mut() {
            if let Some(def_obj) = definition.as_object_mut()
                && def_obj.get("type").and_then(|t| t.as_str()) == Some("object")
            {
                def_obj.insert("additionalProperties".into(), false.into());
            }
        }
    }

    Ok(schema)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_command_suggestions_schema() {
        let schema = build_json_schema_for::<CommandSuggestions>().unwrap();
        println!("{}", serde_json::to_string_pretty(&schema).unwrap());
    }

    #[test]
    fn test_command_fix_schema() {
        let schema = build_json_schema_for::<CommandFix>().unwrap();
        println!("{}", serde_json::to_string_pretty(&schema).unwrap());
    }
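
    // Extra sanity check (a sketch, assuming schemars' `Schema::as_value` accessor):
    // after going through `build_json_schema_for`, the root object and every object
    // under `$defs` should be closed with `additionalProperties: false`
    #[test]
    fn test_schema_objects_are_closed() {
        let schema = build_json_schema_for::<CommandSuggestions>().unwrap();
        let root = schema.as_value();
        assert_eq!(root["additionalProperties"], false);
        if let Some(defs) = root.get("$defs").and_then(|d| d.as_object()) {
            for def in defs.values() {
                if def.get("type").and_then(|t| t.as_str()) == Some("object") {
                    assert_eq!(def["additionalProperties"], false);
                }
            }
        }
    }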
}