// commitbee 0.6.0
//
// AI-powered commit message generator using tree-sitter semantic analysis and local LLMs
// Documentation
// SPDX-FileCopyrightText: 2026 Sephyi <me@sephy.io>
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Commercial

use std::time::Duration;

use reqwest::Client;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio_stream::StreamExt;
use tokio_util::sync::CancellationToken;

use secrecy::{ExposeSecret, SecretString};

use crate::config::Config;
use crate::error::{Error, Result};

use super::MAX_RESPONSE_BYTES;

const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";

/// Streaming chat-completion provider for the OpenAI API (or any
/// OpenAI-compatible endpoint reachable via a custom base URL).
pub struct OpenAiProvider {
    // Reusable HTTP client with a request timeout applied at build time.
    client: Client,
    // Base URL without a trailing slash, e.g. "https://api.openai.com/v1".
    base_url: String,
    // Model identifier sent verbatim in each chat request.
    model: String,
    // Bearer token; wrapped in SecretString so it is not exposed in Debug/logs.
    api_key: SecretString,
    // Sampling temperature forwarded to the API.
    temperature: f32,
    // Upper bound on generated tokens (mapped from config.num_predict).
    max_tokens: u32,
}

/// JSON body for `POST {base_url}/chat/completions`.
#[derive(Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<Message>,
    temperature: f32,
    max_tokens: u32,
    // Always set to true here; the response is consumed as an SSE stream.
    stream: bool,
}

/// A single chat message; `role` is "system" or "user" in this module.
#[derive(Serialize)]
struct Message {
    role: String,
    content: String,
}

/// One parsed SSE `data:` payload from the streaming completions endpoint.
#[derive(Deserialize)]
struct ChatChunk {
    choices: Vec<ChunkChoice>,
}

/// A choice within a streamed chunk; `finish_reason` is non-None on the
/// final chunk for that choice.
#[derive(Deserialize)]
struct ChunkChoice {
    delta: Delta,
    finish_reason: Option<String>,
}

/// Incremental content delta; `content` may be absent (e.g. role-only deltas).
#[derive(Deserialize)]
struct Delta {
    content: Option<String>,
}

impl OpenAiProvider {
    /// Builds a provider from `config`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Provider`] if the underlying HTTP client cannot be
    /// constructed.
    pub fn new(config: &Config) -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(config.timeout_secs))
            .build()
            .map_err(|e| Error::Provider {
                provider: "openai".into(),
                message: format!("failed to build HTTP client: {e}"),
            })?;

        Ok(Self {
            client,
            // Strip trailing slashes so `format!("{}/path", base_url)` never
            // yields a double slash.
            base_url: config
                .openai_base_url
                .clone()
                .unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
                .trim_end_matches('/')
                .to_string(),
            model: config.model.clone(),
            api_key: config.api_key.clone().unwrap_or_default(),
            temperature: config.temperature,
            max_tokens: config.num_predict,
        })
    }

    /// Probes `{base_url}/models` to confirm the endpoint is reachable and
    /// the API key is accepted.
    ///
    /// Only HTTP 401 is treated as a failure; other non-success statuses are
    /// tolerated, presumably because OpenAI-compatible servers may not
    /// implement `/models` — NOTE(review): confirm this is intentional.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Provider`] on transport failure or an invalid key.
    pub async fn verify_connection(&self) -> Result<()> {
        let url = format!("{}/models", self.base_url);

        let response = self
            .client
            .get(&url)
            .header(
                "Authorization",
                format!("Bearer {}", self.api_key.expose_secret()),
            )
            .send()
            .await
            .map_err(|e| Error::Provider {
                provider: "openai".into(),
                // without_url() keeps the (key-bearing) URL out of error text.
                message: e.without_url().to_string(),
            })?;

        if response.status() == reqwest::StatusCode::UNAUTHORIZED {
            return Err(Error::Provider {
                provider: "openai".into(),
                message: "invalid API key".into(),
            });
        }

        Ok(())
    }

    /// Streams a chat completion for `prompt` (with `system_prompt` as the
    /// system message), forwarding each content delta through `token_tx` and
    /// returning the full trimmed response.
    ///
    /// Cancellation via `cancel` aborts the stream with [`Error::Cancelled`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::Provider`] on transport errors, non-success HTTP
    /// status, or when the buffered response exceeds `MAX_RESPONSE_BYTES`,
    /// and [`Error::Cancelled`] if the token fires first.
    pub async fn generate(
        &self,
        prompt: &str,
        system_prompt: &str,
        token_tx: mpsc::Sender<String>,
        cancel: CancellationToken,
    ) -> Result<String> {
        let url = format!("{}/chat/completions", self.base_url);

        let response = self
            .client
            .post(&url)
            .header(
                "Authorization",
                format!("Bearer {}", self.api_key.expose_secret()),
            )
            .json(&ChatRequest {
                model: self.model.clone(),
                messages: vec![
                    Message {
                        role: "system".into(),
                        content: system_prompt.into(),
                    },
                    Message {
                        role: "user".into(),
                        content: prompt.to_string(),
                    },
                ],
                temperature: self.temperature,
                max_tokens: self.max_tokens,
                stream: true,
            })
            .send()
            .await
            .map_err(|e| {
                if e.is_timeout() {
                    Error::Provider {
                        provider: "openai".into(),
                        message: "request timed out".into(),
                    }
                } else {
                    Error::Provider {
                        provider: "openai".into(),
                        message: e.without_url().to_string(),
                    }
                }
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response
                .text()
                .await
                .unwrap_or_else(|e| format!("(failed to read body: {e})"));
            return Err(Error::Provider {
                provider: "openai".into(),
                message: format!("HTTP {status}: {body}"),
            });
        }

        let mut stream = response.bytes_stream();
        let mut full_response = String::new();
        // Buffer raw bytes rather than lossily-decoded text: network chunks
        // can split a multi-byte UTF-8 character at an arbitrary boundary,
        // and decoding each chunk independently would corrupt it into
        // U+FFFD replacement characters. A '\n' byte can never occur inside
        // a multi-byte UTF-8 sequence, so splitting on it and decoding only
        // complete lines is always safe.
        let mut line_buffer: Vec<u8> = Vec::new();

        loop {
            tokio::select! {
                _ = cancel.cancelled() => {
                    return Err(Error::Cancelled);
                }
                chunk = stream.next() => {
                    // None = end of stream; fall through to the tail handling.
                    let Some(chunk) = chunk else { break };

                    let chunk = chunk.map_err(|e| Error::Provider {
                        provider: "openai".into(),
                        message: e.without_url().to_string(),
                    })?;

                    line_buffer.extend_from_slice(&chunk);

                    // NOTE(review): message assumes MAX_RESPONSE_BYTES is
                    // 1 MB — confirm against super::MAX_RESPONSE_BYTES.
                    if line_buffer.len() > MAX_RESPONSE_BYTES {
                        return Err(Error::Provider {
                            provider: "openai".into(),
                            message: "line buffer exceeded 1 MB limit".into(),
                        });
                    }

                    while let Some(newline_pos) =
                        line_buffer.iter().position(|&b| b == b'\n')
                    {
                        // Decode and parse just this complete line; scoped so
                        // the borrow ends before we drain the buffer.
                        let result = {
                            let line = String::from_utf8_lossy(&line_buffer[..newline_pos]);
                            let line = line.trim();
                            if line.is_empty() || line == "data: [DONE]" {
                                None
                            } else if let Some(data) = line.strip_prefix("data: ") {
                                // Unparseable payloads are skipped, not fatal.
                                serde_json::from_str::<ChatChunk>(data).ok()
                            } else {
                                None
                            }
                        };
                        // Shift buffer in-place (no allocation)
                        line_buffer.drain(..=newline_pos);

                        if let Some(chunk) = result {
                            for choice in &chunk.choices {
                                if let Some(ref content) = choice.delta.content {
                                    // Receiver may have gone away; streaming
                                    // UI is best-effort, so ignore send errors.
                                    let _ = token_tx.send(content.clone()).await;
                                    full_response.push_str(content);
                                }
                                if full_response.len() > MAX_RESPONSE_BYTES {
                                    return Err(Error::Provider {
                                        provider: "openai".into(),
                                        message: "response exceeded 1 MB limit".into(),
                                    });
                                }
                                if choice.finish_reason.is_some() {
                                    return Ok(full_response.trim().to_string());
                                }
                            }
                        }
                    }
                }
            }
        }

        // Handle any remaining content in buffer after stream ends
        // (a final `data:` line not terminated by '\n').
        if !line_buffer.is_empty() {
            let tail = String::from_utf8_lossy(&line_buffer);
            let line = tail.trim();
            if !line.is_empty()
                && line != "data: [DONE]"
                && let Some(data) = line.strip_prefix("data: ")
                && let Ok(chunk) = serde_json::from_str::<ChatChunk>(data)
            {
                for choice in &chunk.choices {
                    if let Some(ref content) = choice.delta.content {
                        full_response.push_str(content);
                    }
                }
            }
        }

        Ok(full_response.trim().to_string())
    }

    /// Stable provider identifier used in error reporting and dispatch.
    pub fn name(&self) -> &str {
        "openai"
    }
}