// ************************************************************************** //
//                                                                            //
//                                                        :::      ::::::::   //
//   llm.rs                                             :+:      :+:    :+:   //
//                                                    +:+ +:+         +:+     //
//   By: dfine <coding@dfine.tech>                  +#+  +:+       +#+        //
//                                                +#+#+#+#+#+   +#+           //
//   Created: 2025/05/10 19:12:36 by dfine             #+#    #+#             //
//   Updated: 2025/05/10 19:12:37 by dfine            ###   ########.fr       //
//                                                                            //
// ************************************************************************** //

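//! Thin async wrappers around an OpenAI-compatible chat API and a local Ollama
//! server, used by `git_commit_helper` to turn a prompt (typically a git diff)
//! into generated text such as a commit message.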
use async_openai::{
    Client as OpenAIClient,
    config::OpenAIConfig,
    types::{ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs},
};
use ollama_rs::{IntoUrlSealed, Ollama, generation::completion::request::GenerationRequest};
use std::error::Error;

/// Sends a prompt to the OpenAI chat API and returns the generated response as a string.
///
/// This function uses the `async-openai` crate to interact with a chat completion endpoint
/// (e.g., GPT-4, GPT-4o, GPT-3.5-turbo). The base URL can be overridden via the
/// `OPENAI_BASE_URL` environment variable.
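/// The API key is read from the `OPENAI_API_KEY` environment variable by the
/// default `async-openai` configuration.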
///
/// # Arguments
///
/// * `prompt` - The text prompt to send to the model.
/// * `model` - The model ID to use (e.g., `"gpt-4o"`, `"gpt-3.5-turbo"`).
/// * `max_token` - Maximum number of tokens allowed in the response.
///
/// # Returns
///
/// A `Result` containing the generated string response on success, or an error on failure.
///
/// # Errors
///
/// This function will return an error if the request fails, the environment variable
/// is misconfigured, or if the response cannot be parsed correctly.
///
/// # Example
///
/// ```no_run
/// use git_commit_helper::call_openai;
///
/// #[tokio::main]
/// async fn main() {
///     let prompt = "Summarize the following diff...";
///     let model = "gpt-4o";
///     let max_token = 2048;
///
///     match call_openai(prompt, model, max_token).await {
///         Ok(response) => println!("LLM response: {}", response),
///         Err(e) => eprintln!("Error calling OpenAI: {}", e),
///     }
/// }
/// ```
pub async fn call_openai(
    prompt: &str,
    model: &str,
    max_token: u32,
) -> Result<String, Box<dyn Error>> {
    let base_url = std::env::var("OPENAI_BASE_URL")
        .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
    let config = OpenAIConfig::default().with_api_base(base_url);
    let client = OpenAIClient::with_config(config);
    let request = CreateChatCompletionRequestArgs::default()
        .max_tokens(max_token)
        .model(model)
        .messages([ChatCompletionRequestUserMessageArgs::default()
            .content(prompt)
            .build()?
            .into()])
        .build()?;
    let response = client.chat().create(request).await?;
    Ok(response
        .choices
        .first()
        .and_then(|c| c.message.content.clone())
        .unwrap_or_default())
}

/// Sends a prompt to the Ollama API and returns the generated response as a string.
///
/// This function uses the `ollama_rs` crate to interact with Ollama's generation endpoint.
/// The base URL can be overridden via the `OLLAMA_BASE_URL` environment variable.
///
/// # Arguments
///
/// * `prompt` - The text prompt to send to the model.
/// * `model` - The model ID to use for the request.
/// * `_max_token` - Currently unused parameter (to maintain function signature consistency).
///
/// # Returns
///
/// A `Result` containing the generated string response on success, or an error on failure.
///
/// # Errors
///
/// This function will return an error if the request fails, the environment variable
/// is misconfigured, or if the response cannot be handled correctly.
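///
/// # Example
///
/// ```no_run
/// use git_commit_helper::call_ollama;
///
/// // A minimal sketch: it assumes an Ollama server is reachable at the default
/// // `http://localhost:11434` (or at `OLLAMA_BASE_URL`) and that the named
/// // model has already been pulled; "llama3" is only an illustrative choice.
/// #[tokio::main]
/// async fn main() {
///     let prompt = "Summarize the following diff...";
///     let model = "llama3";
///     let max_token = 2048;
///
///     match call_ollama(prompt, model, max_token).await {
///         Ok(response) => println!("LLM response: {}", response),
///         Err(e) => eprintln!("Error calling Ollama: {}", e),
///     }
/// }
/// ```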
pub async fn call_ollama(
    prompt: &str,
    model: &str,
    _max_token: u32,
) -> Result<String, Box<dyn Error>> {
    let base_url =
        std::env::var("OLLAMA_BASE_URL").unwrap_or_else(|_| "http://localhost:11434".to_string());
    let url = base_url.into_url()?;
    let client = Ollama::from_url(url);
    let response = client
        .generate(GenerationRequest::new(model.to_string(), prompt))
        .await?;
    Ok(response.response)
}

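/// Dispatches a prompt to the selected LLM provider and returns the generated response.
///
/// `"ollama"` routes to [`call_ollama`]; any other provider string falls back to
/// the OpenAI-compatible client via [`call_openai`].
///
/// # Arguments
///
/// * `provider` - The backend to use (`"ollama"` for a local Ollama server, anything else for OpenAI).
/// * `prompt` - The text prompt to send to the model.
/// * `model` - The model ID to use for the request.
/// * `max_token` - Maximum number of tokens allowed in the response (ignored by the Ollama backend).
///
/// # Errors
///
/// Returns any error propagated from [`call_openai`] or [`call_ollama`].
///
/// # Example
///
/// ```no_run
/// use git_commit_helper::call_llm;
///
/// #[tokio::main]
/// async fn main() {
///     // "llama3" is only an illustrative model name; any provider string other
///     // than "ollama" routes the request to the OpenAI-compatible client.
///     match call_llm("ollama", "Summarize the following diff...", "llama3", 2048).await {
///         Ok(response) => println!("LLM response: {}", response),
///         Err(e) => eprintln!("Error calling LLM: {}", e),
///     }
/// }
/// ```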
pub async fn call_llm(
    provider: &str,
    prompt: &str,
    model: &str,
    max_token: u32,
) -> Result<String, Box<dyn Error>> {
    match provider {
        "ollama" => call_ollama(prompt, model, max_token).await,
        _ => call_openai(prompt, model, max_token).await,
    }
}