intelli_shell/ai/
openai.rs1use std::fmt::Debug;
2
3use color_eyre::eyre::Context;
4use reqwest::{
5 Client, RequestBuilder, Response,
6 header::{self, HeaderName},
7};
8use schemars::{JsonSchema, Schema};
9use serde::{Deserialize, de::DeserializeOwned};
10use serde_json::{Value as Json, json};
11
12use super::{AiProvider, AiProviderBase};
13use crate::{
14 config::OpenAiModelConfig,
15 errors::{Result, UserFacingError},
16};
17
18impl AiProviderBase for OpenAiModelConfig {
19 fn provider_name(&self) -> &'static str {
20 "OpenAI"
21 }
22
23 fn auth_header(&self, api_key: String) -> (HeaderName, String) {
24 (header::AUTHORIZATION, format!("Bearer {api_key}"))
25 }
26
27 fn api_key_env_var_name(&self) -> &str {
28 &self.api_key_env
29 }
30
31 fn build_request(
32 &self,
33 client: &Client,
34 sys_prompt: &str,
35 user_prompt: &str,
36 json_schema: &Schema,
37 ) -> RequestBuilder {
38 let request_body = json!({
41 "model": self.model,
42 "messages": [
43 {
44 "role": "system",
45 "content": sys_prompt
46 },
47 {
48 "role": "user",
49 "content": user_prompt
50 }
51 ],
52 "response_format": {
53 "type": "json_schema",
54 "json_schema": {
55 "name": "command_suggestions",
56 "strict": true,
57 "schema": json_schema
58 }
59 }
60 });
61
62 tracing::trace!("Request:\n{request_body:#}");
63
64 let url = format!("{}/chat/completions", self.url);
66
67 client.post(url).json(&request_body)
69 }
70}
71
72impl AiProvider for OpenAiModelConfig {
73 async fn parse_response<T>(&self, res: Response) -> Result<T>
74 where
75 T: DeserializeOwned + JsonSchema + Debug,
76 {
77 let res: Json = res.json().await.wrap_err("OpenAI response not a json")?;
79 tracing::trace!("Response:\n{res:#}");
80 let mut res: OpenAiResponse = serde_json::from_value(res).wrap_err("Couldn't parse OpenAI response")?;
81
82 if res.choices.is_empty() {
84 tracing::error!("Response got no choices: {res:?}");
85 return Err(UserFacingError::AiRequestFailed(String::from("received response with no choices")).into());
86 } else if res.choices.len() > 1 {
87 tracing::warn!("Response got {} choices", res.choices.len());
88 }
89
90 let choice = res.choices.remove(0);
91 if choice.finish_reason != "stop" {
92 tracing::error!("OpenAI response got an invalid finish reason: {}", choice.finish_reason);
93 return Err(UserFacingError::AiRequestFailed(format!(
94 "couldn't generate a valid response: {}",
95 choice.finish_reason
96 ))
97 .into());
98 }
99
100 if let Some(refusal) = choice.message.refusal
101 && !refusal.is_empty()
102 {
103 tracing::error!("OpenAI refused to answer: {refusal}");
104 return Err(UserFacingError::AiRequestFailed(format!("response refused: {refusal}")).into());
105 }
106
107 let Some(message) = choice.message.content.filter(|c| !c.trim().is_empty()) else {
108 tracing::error!("OpenAI returned an empty response");
109 return Err(UserFacingError::AiRequestFailed(String::from("received an empty response")).into());
110 };
111
112 Ok(serde_json::from_str(&message).map_err(|err| {
114 tracing::error!("Couldn't parse API response into the expected format: {err}\nMessage:\n{message}");
115 UserFacingError::AiRequestFailed(String::from("couldn't parse api response into the expected format"))
116 })?)
117 }
118}
119
120#[derive(Debug, Deserialize)]
121struct OpenAiResponse {
122 choices: Vec<OpenAiChoice>,
123}
124
125#[derive(Debug, Deserialize)]
126struct OpenAiChoice {
127 message: OpenAiResponseMessage,
128 finish_reason: String,
129}
130
131#[derive(Debug, Deserialize)]
132struct OpenAiResponseMessage {
133 #[serde(default)]
134 refusal: Option<String>,
135 #[serde(default)]
136 content: Option<String>,
137}
138
#[cfg(test)]
mod tests {
    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

    use super::*;
    use crate::{ai::AiClient, config::AiModelConfig};

    /// Manual smoke test that asks the configured OpenAI endpoint for command suggestions and logs them.
    ///
    /// Ignored by default: it targets `https://api.openai.com/v1` and presumably needs a valid key in
    /// `OPENAI_API_KEY` (confirm against `AiClient`).
    #[tokio::test]
    #[ignore]
    async fn test_openai_api() -> Result<()> {
        // `try_init` instead of `init`: the latter panics when a global subscriber is already installed,
        // e.g. by another test running in the same process
        let _ = tracing_subscriber::registry()
            .with(tracing_subscriber::fmt::layer().compact())
            .try_init();
        let config = AiModelConfig::Openai(OpenAiModelConfig {
            model: "gpt-4.1-nano".into(),
            url: "https://api.openai.com/v1".into(),
            api_key_env: "OPENAI_API_KEY".into(),
        });
        let client = AiClient::new("test", &config, "", None)?;
        let res = client
            .generate_command_suggestions(
                "you're a cli expert, that will provide command suggestions based on what the user wants to do",
                "undo last n amount of commits",
            )
            .await?;
        tracing::info!("Suggestions:");
        for command in res.suggestions {
            tracing::info!("  # {}", command.description);
            tracing::info!("    {}", command.command);
        }
        Ok(())
    }
}