git_commit_helper/llm.rs
// ************************************************************************** //
//                                                                            //
//                                                        :::      ::::::::   //
//   llm.rs                                             :+:      :+:    :+:   //
//                                                    +:+ +:+         +:+     //
//   By: dfine <coding@dfine.tech>                  +#+  +:+       +#+        //
//                                                +#+#+#+#+#+   +#+           //
//   Created: 2025/05/10 19:12:36 by dfine        #+#    #+#                  //
//   Updated: 2025/05/10 19:12:37 by dfine       ###   ########.fr            //
//                                                                            //
// ************************************************************************** //

use async_openai::{
    Client as OpenAIClient,
    config::OpenAIConfig,
    types::{ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs},
};
use ollama_rs::{IntoUrlSealed, Ollama, generation::completion::request::GenerationRequest};
use std::error::Error;

/// Sends a prompt to the OpenAI chat API and returns the generated response as a string.
///
/// This function uses the `async-openai` crate to interact with a chat completion endpoint
/// (e.g., GPT-4, GPT-4o, GPT-3.5-turbo). The base URL can be overridden via the
/// `OPENAI_BASE_URL` environment variable.
///
/// # Arguments
///
/// * `prompt` - The text prompt to send to the model.
/// * `model` - The model ID to use (e.g., `"gpt-4o"`, `"gpt-3.5-turbo"`).
/// * `max_token` - Maximum number of tokens allowed in the response.
///
/// # Returns
///
/// A `Result` containing the generated string response on success, or an error on failure.
///
/// # Errors
///
/// This function will return an error if the request fails, the environment variable
/// is misconfigured, or if the response cannot be parsed correctly.
///
/// # Example
///
/// ```no_run
/// use git_commit_helper::call_openai;
///
/// #[tokio::main]
/// async fn main() {
///     let prompt = "Summarize the following diff...";
///     let model = "gpt-4o";
///     let max_token = 2048;
///
///     match call_openai(prompt, model, max_token).await {
///         Ok(response) => println!("LLM response: {}", response),
///         Err(e) => eprintln!("Error calling OpenAI: {}", e),
///     }
/// }
/// ```
pub async fn call_openai(
    prompt: &str,
    model: &str,
    max_token: u32,
) -> Result<String, Box<dyn Error>> {
    // Allow pointing the client at any OpenAI-compatible endpoint.
    let base_url = std::env::var("OPENAI_BASE_URL")
        .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
    let config = OpenAIConfig::default().with_api_base(base_url);
    let client = OpenAIClient::with_config(config);
    let request = CreateChatCompletionRequestArgs::default()
        .max_tokens(max_token)
        .model(model)
        .messages([ChatCompletionRequestUserMessageArgs::default()
            .content(prompt)
            .build()?
            .into()])
        .build()?;
    let response = client.chat().create(request).await?;
    // Return the first choice's content, or an empty string if none was produced.
    Ok(response
        .choices
        .first()
        .and_then(|c| c.message.content.clone())
        .unwrap_or_default())
}

/// Sends a prompt to the Ollama API and returns the generated response as a string.
///
/// This function uses the `ollama_rs` crate to interact with Ollama's generation endpoint.
/// The base URL can be overridden via the `OLLAMA_BASE_URL` environment variable.
///
/// # Arguments
///
/// * `prompt` - The text prompt to send to the model.
/// * `model` - The model ID to use for the request.
/// * `_max_token` - Currently unused; kept for signature consistency with `call_openai`.
///
/// # Returns
///
/// A `Result` containing the generated string response on success, or an error on failure.
///
/// # Errors
///
/// This function will return an error if the request fails, the environment variable
/// is misconfigured, or if the response cannot be handled correctly.
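///
/// # Example
///
/// A minimal usage sketch mirroring the `call_openai` example above; the `llama3`
/// model name is illustrative, and the import assumes `call_ollama` is re-exported
/// at the crate root alongside `call_openai`.
///
/// ```no_run
/// use git_commit_helper::call_ollama;
///
/// #[tokio::main]
/// async fn main() {
///     let prompt = "Summarize the following diff...";
///     let model = "llama3"; // any locally pulled Ollama model
///     let max_token = 2048; // currently ignored by call_ollama
///
///     match call_ollama(prompt, model, max_token).await {
///         Ok(response) => println!("LLM response: {}", response),
///         Err(e) => eprintln!("Error calling Ollama: {}", e),
///     }
/// }
/// ```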
pub async fn call_ollama(
    prompt: &str,
    model: &str,
    _max_token: u32,
) -> Result<String, Box<dyn Error>> {
    // Allow overriding the Ollama endpoint; default to a local server.
    let base_url =
        std::env::var("OLLAMA_BASE_URL").unwrap_or_else(|_| "http://localhost:11434".to_string());
    let url = base_url.into_url()?;
    let client = Ollama::from_url(url);
    let response = client
        .generate(GenerationRequest::new(model.to_string(), prompt))
        .await?;
    Ok(response.response)
}

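/// Dispatches a prompt to the selected LLM provider and returns the generated response.
///
/// Passing `"ollama"` as the provider routes the request to [`call_ollama`]; any other
/// value falls back to the OpenAI-compatible backend via [`call_openai`].
///
/// # Arguments
///
/// * `provider` - `"ollama"` for a local Ollama server, anything else for OpenAI.
/// * `prompt` - The text prompt to send to the model.
/// * `model` - The model ID to use for the request.
/// * `max_token` - Maximum number of tokens in the response (ignored by the Ollama backend).
///
/// # Errors
///
/// Propagates any error returned by the underlying provider call.
///
/// # Example
///
/// A minimal usage sketch mirroring the examples above; the provider and model values
/// are illustrative, and the import assumes `call_llm` is re-exported at the crate root.
///
/// ```no_run
/// use git_commit_helper::call_llm;
///
/// #[tokio::main]
/// async fn main() {
///     match call_llm("ollama", "Summarize the following diff...", "llama3", 2048).await {
///         Ok(response) => println!("LLM response: {}", response),
///         Err(e) => eprintln!("Error calling LLM: {}", e),
///     }
/// }
/// ```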
pub async fn call_llm(
    provider: &str,
    prompt: &str,
    model: &str,
    max_token: u32,
) -> Result<String, Box<dyn Error>> {
    match provider {
        "ollama" => call_ollama(prompt, model, max_token).await,
        // Any other provider value falls back to the OpenAI-compatible client.
        _ => call_openai(prompt, model, max_token).await,
    }
}