1use std::{env, fmt::Debug, time::Duration};
2
3use color_eyre::eyre::{Context, ContextCompat, eyre};
4use reqwest::{
5 Client, ClientBuilder, RequestBuilder, Response, StatusCode,
6 header::{self, HeaderMap, HeaderName, HeaderValue},
7};
8use schemars::{JsonSchema, Schema, schema_for};
9use serde::{Deserialize, de::DeserializeOwned};
10use tokio_util::sync::CancellationToken;
11use tracing::instrument;
12
13use crate::{
14 config::AiModelConfig,
15 errors::{AppError, Result, UserFacingError},
16};
17
18mod anthropic;
19mod gemini;
20mod ollama;
21mod openai;
22
/// Provider-agnostic surface shared by every AI provider backend.
///
/// NOTE(review): concrete implementors live in the `anthropic`, `gemini`,
/// `ollama` and `openai` submodules, which are not visible here.
pub trait AiProviderBase: Send + Sync {
    /// Human-readable provider name (used for logging / tracing fields)
    fn provider_name(&self) -> &'static str;

    /// Builds the authentication header (name + value) for the given API key
    fn auth_header(&self, api_key: String) -> (HeaderName, String);

    /// Name of the environment variable that holds this provider's API key
    fn api_key_env_var_name(&self) -> &str;

    /// Builds the HTTP request for a generation call constrained by the given
    /// JSON schema; authentication is attached later by the caller
    fn build_request(
        &self,
        client: &Client,
        sys_prompt: &str,
        user_prompt: &str,
        json_schema: &Schema,
    ) -> RequestBuilder;
}
43
/// Response-parsing half of a provider; `trait_variant::make(Send)` generates
/// a `Send`-bounded variant of the async method so it can cross await points
/// in multithreaded executors.
#[trait_variant::make(Send)]
pub trait AiProvider: AiProviderBase {
    /// Parses a successful HTTP response body into the requested output type
    async fn parse_response<T>(&self, res: Response) -> Result<T>
    where
        Self: Sized,
        T: DeserializeOwned + JsonSchema + Debug;
}
53
/// Client that sends prompts to a primary AI model and, when configured,
/// retries once against a fallback model on retriable failures.
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct AiClient<'a> {
    // Shared reqwest client carrying default JSON headers and timeouts
    inner: Client,
    // Alias of the primary model, used only in log messages
    primary_alias: &'a str,
    // Configuration of the primary model
    primary: &'a AiModelConfig,
    // Alias of the fallback model, used only in log messages
    fallback_alias: &'a str,
    // Optional fallback model configuration; `None` disables the retry
    fallback: Option<&'a AiModelConfig>,
}
63impl<'a> AiClient<'a> {
64 pub fn new(
66 primary_alias: &'a str,
67 primary: &'a AiModelConfig,
68 fallback_alias: &'a str,
69 fallback: Option<&'a AiModelConfig>,
70 ) -> Result<Self> {
71 let mut headers = HeaderMap::new();
73 headers.append(header::CONTENT_TYPE, HeaderValue::from_static("application/json"));
74
75 let inner = ClientBuilder::new()
77 .connect_timeout(Duration::from_secs(5))
78 .timeout(Duration::from_secs(5 * 60))
79 .user_agent("intelli-shell")
80 .default_headers(headers)
81 .build()
82 .wrap_err("Couldn't build AI client")?;
83
84 Ok(AiClient {
85 inner,
86 primary_alias,
87 primary,
88 fallback_alias,
89 fallback,
90 })
91 }
92
93 #[instrument(skip_all)]
95 pub async fn generate_command_suggestions(
96 &self,
97 sys_prompt: &str,
98 user_prompt: &str,
99 cancellation_token: CancellationToken,
100 ) -> Result<CommandSuggestions> {
101 self.generate_content(sys_prompt, user_prompt, cancellation_token).await
102 }
103
104 #[instrument(skip_all)]
106 pub async fn generate_command_fix(
107 &self,
108 sys_prompt: &str,
109 user_prompt: &str,
110 cancellation_token: CancellationToken,
111 ) -> Result<CommandFix> {
112 self.generate_content(sys_prompt, user_prompt, cancellation_token).await
113 }
114
115 #[instrument(skip_all)]
117 pub async fn generate_completion_suggestion(
118 &self,
119 sys_prompt: &str,
120 user_prompt: &str,
121 cancellation_token: CancellationToken,
122 ) -> Result<VariableCompletionSuggestion> {
123 self.generate_content(sys_prompt, user_prompt, cancellation_token).await
124 }
125
126 async fn generate_content<T>(
130 &self,
131 sys_prompt: &str,
132 user_prompt: &str,
133 cancellation_token: CancellationToken,
134 ) -> Result<T>
135 where
136 T: DeserializeOwned + JsonSchema + Debug,
137 {
138 let primary_result = self
140 .execute_request(self.primary, sys_prompt, user_prompt, cancellation_token.clone())
141 .await;
142
143 if let Err(AppError::UserFacing(UserFacingError::AiRateLimit)) = &primary_result {
145 if let Some(fallback) = self.fallback {
147 tracing::warn!(
148 "Primary model ({}) rate-limited, retrying with fallback ({})",
149 self.primary_alias,
150 self.fallback_alias
151 );
152 return self
153 .execute_request(fallback, sys_prompt, user_prompt, cancellation_token)
154 .await;
155 }
156 }
157
158 if let Err(AppError::UserFacing(UserFacingError::AiUnavailable)) = &primary_result {
160 if let Some(fallback) = self.fallback {
162 tracing::warn!(
163 "Primary model ({}) unavailable, retrying with fallback ({})",
164 self.primary_alias,
165 self.fallback_alias
166 );
167 return self
168 .execute_request(fallback, sys_prompt, user_prompt, cancellation_token)
169 .await;
170 }
171 }
172
173 primary_result
175 }
176
177 #[instrument(skip_all, fields(provider = config.provider().provider_name()))]
179 async fn execute_request<T>(
180 &self,
181 config: &AiModelConfig,
182 sys_prompt: &str,
183 user_prompt: &str,
184 cancellation_token: CancellationToken,
185 ) -> Result<T>
186 where
187 T: DeserializeOwned + JsonSchema + Debug,
188 {
189 let provider = config.provider();
190
191 let json_schema = build_json_schema_for::<T>()?;
193
194 let mut req_builder = provider.build_request(&self.inner, sys_prompt, user_prompt, &json_schema);
196
197 let api_key_env = provider.api_key_env_var_name();
199 if let Ok(api_key) = env::var(api_key_env) {
200 let (header_name, header_value) = provider.auth_header(api_key);
201 let mut header_value =
202 HeaderValue::from_str(&header_value).wrap_err_with(|| format!("Invalid '{api_key_env}' value"))?;
203 header_value.set_sensitive(true);
204 req_builder = req_builder.header(header_name, header_value);
205 }
206
207 let req = req_builder.build().wrap_err("Couldn't build api request")?;
209
210 tracing::debug!("Calling {} API: {}", provider.provider_name(), req.url());
212 let res = tokio::select! {
213 biased;
214 _ = cancellation_token.cancelled() => {
215 return Err(UserFacingError::Cancelled.into());
216 }
217 res = self.inner.execute(req) => {
218 res.map_err(|err| {
219 if err.is_timeout() {
220 tracing::error!("Request timeout: {err:?}");
221 UserFacingError::AiRequestTimeout
222 } else if err.is_connect() {
223 tracing::error!("Couldn't connect to the API: {err:?}");
224 UserFacingError::AiRequestFailed(String::from("error connecting to the provider"))
225 } else {
226 tracing::error!("Couldn't perform the request: {err:?}");
227 UserFacingError::AiRequestFailed(err.to_string())
228 }
229 })?
230 }
231 };
232
233 if !res.status().is_success() {
235 let status = res.status();
236 let status_str = status.as_str();
237 let body = res.text().await.unwrap_or_default();
238 if status == StatusCode::UNAUTHORIZED || status == StatusCode::FORBIDDEN {
239 tracing::warn!(
240 "Got response [{status_str}] {}",
241 status.canonical_reason().unwrap_or_default()
242 );
243 tracing::debug!("{body}");
244 return Err(
245 UserFacingError::AiMissingOrInvalidApiKey(provider.api_key_env_var_name().to_string()).into(),
246 );
247 } else if status == StatusCode::TOO_MANY_REQUESTS {
248 tracing::info!("Got response [{status_str}] Too Many Requests");
249 tracing::debug!("{body}");
250 return Err(UserFacingError::AiRateLimit.into());
251 } else if status == StatusCode::SERVICE_UNAVAILABLE {
252 tracing::info!("Got response [{status_str}] Service Unavailable");
253 tracing::debug!("{body}");
254 return Err(UserFacingError::AiUnavailable.into());
255 } else if status == StatusCode::BAD_REQUEST {
256 tracing::error!("Got response [{status_str}] Bad Request:\n{body}");
257 return Err(eyre!("Bad request while fetching {} API:\n{body}", provider.provider_name()).into());
258 } else if let Some(reason) = status.canonical_reason() {
259 tracing::error!("Got response [{status_str}] {reason}:\n{body}");
260 return Err(
261 UserFacingError::AiRequestFailed(format!("received {status_str} {reason} response")).into(),
262 );
263 } else {
264 tracing::error!("Got response [{status_str}]:\n{body}");
265 return Err(UserFacingError::AiRequestFailed(format!("received {status_str} response")).into());
266 }
267 }
268
269 match &config {
271 AiModelConfig::Openai(conf) => conf.parse_response(res).await,
272 AiModelConfig::Gemini(conf) => conf.parse_response(res).await,
273 AiModelConfig::Anthropic(conf) => conf.parse_response(res).await,
274 AiModelConfig::Ollama(conf) => conf.parse_response(res).await,
275 }
276 }
277}
278
/// Structured output: a list of command suggestions returned by the model
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandSuggestions {
    // The suggested commands, in the order the model produced them
    pub suggestions: Vec<CommandSuggestion>,
}
284
/// A single command suggested by the model
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandSuggestion {
    // Description of the suggested command
    pub description: String,
    // The suggested command itself
    pub command: String,
}
296
/// Structured output: the model's diagnosis and fix for a failing command
#[derive(Debug, Deserialize, JsonSchema)]
pub struct CommandFix {
    // Short summary of the problem
    pub summary: String,
    // Explanation of what went wrong
    pub diagnosis: String,
    // Proposed way to address the problem
    pub proposal: String,
    // The corrected command
    pub fixed_command: String,
}
315
/// Structured output: a single completion suggestion for a command variable
#[derive(Debug, Deserialize, JsonSchema)]
pub struct VariableCompletionSuggestion {
    // The suggested completion value
    pub command: String,
}
322
323fn build_json_schema_for<T: JsonSchema>() -> Result<Schema> {
325 let mut schema = schema_for!(T);
327
328 let root = schema.as_object_mut().wrap_err("The type must be an object")?;
330 root.insert("additionalProperties".into(), false.into());
331
332 if let Some(defs) = root.get_mut("$defs") {
334 for definition in defs.as_object_mut().wrap_err("Expected objects at $defs")?.values_mut() {
335 if let Some(def_obj) = definition.as_object_mut()
336 && def_obj.get("type").and_then(|t| t.as_str()) == Some("object")
337 {
338 def_obj.insert("additionalProperties".into(), false.into());
339 }
340 }
341 }
342
343 Ok(schema)
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    // Builds the strict schema for `T` and prints it for manual inspection
    fn print_schema<T: JsonSchema>() {
        let schema = build_json_schema_for::<T>().unwrap();
        println!("{}", serde_json::to_string_pretty(&schema).unwrap());
    }

    #[test]
    fn test_command_suggestions_schema() {
        print_schema::<CommandSuggestions>();
    }

    #[test]
    fn test_command_fix_schema() {
        print_schema::<CommandFix>();
    }
}