argyph-locate 1.0.4

Local-first MCP server giving AI coding agents fast, structured, and semantic context over any codebase.
Documentation
//! OpenAI / OpenAI-compatible provider.

use crate::smart::model::{
    redact, ApiKey, LocateModel, LocateModelError, Message, ModelStep, Role,
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};

/// Maximum number of times `step` re-sends a request after an HTTP 5xx response, so one step
/// issues at most `MAX_RETRIES + 1` requests. HTTP 429 responses and transport errors are not
/// retried.
const MAX_RETRIES: u32 = 3;

/// Chat-completions model provider for OpenAI and OpenAI-compatible endpoints.
///
/// Construct with [`OpenAiModel::from_env`]. The derived `Debug` output is safe to log: the API
/// key is printed through `ApiKey`'s own redacting `Debug` implementation.
#[derive(Debug)]
pub struct OpenAiModel {
    /// Bearer token sent with every request; scrubbed from error messages via `redact`.
    api_key: ApiKey,
    /// Model name sent as the request's `model` field.
    model: String,
    /// Full URL that chat-completion requests are POSTed to.
    endpoint: String,
    /// HTTP client used for every request made by this provider.
    client: reqwest::Client,
}

impl OpenAiModel {
    /// Creates a provider that authenticates with the `OPENAI_API_KEY` environment variable.
    ///
    /// * `model` — model name sent with every request.
    /// * `endpoint` — chat-completions URL to POST to. `None` means
    ///   `https://api.openai.com/v1/chat/completions`; pass a custom URL for OpenAI-compatible
    ///   servers.
    ///
    /// Leading and trailing whitespace is stripped from the key. A stray newline (common when the
    /// key is loaded from a file) is not a valid header character, and would otherwise fail only
    /// later, when the first request's `Authorization` header is built.
    ///
    /// # Errors
    /// Returns [`LocateModelError::Provider`] if `OPENAI_API_KEY` is unset, is not valid Unicode,
    /// or is empty after trimming. The key's value is never included in the message.
    pub fn from_env(model: String, endpoint: Option<String>) -> Result<Self, LocateModelError> {
        let key = match std::env::var("OPENAI_API_KEY") {
            Ok(key) => key,
            Err(std::env::VarError::NotPresent) => {
                return Err(LocateModelError::Provider("OPENAI_API_KEY not set".into()));
            }
            // Distinguish this from "not set" so users don't go hunting for a missing variable.
            // The raw OsString is deliberately dropped: it is the secret.
            Err(std::env::VarError::NotUnicode(_)) => {
                return Err(LocateModelError::Provider(
                    "OPENAI_API_KEY is not valid Unicode".into(),
                ));
            }
        };
        let key = key.trim();
        if key.is_empty() {
            return Err(LocateModelError::Provider("OPENAI_API_KEY is empty".into()));
        }
        Ok(Self {
            api_key: ApiKey::new(key.to_string()),
            model,
            endpoint: endpoint
                .unwrap_or_else(|| "https://api.openai.com/v1/chat/completions".into()),
            client: reqwest::Client::new(),
        })
    }
}

/// JSON body of a chat-completions request, as built by `step`.
#[derive(Serialize)]
struct ChatRequest<'a> {
    /// Model name.
    model: &'a str,
    /// Conversation so far, in the order it is sent.
    messages: &'a [ChatMessage],
    /// Sampling temperature; `step` always sends `0.0`.
    temperature: f32,
    /// Requested output format; `step` always asks for `json_object`.
    response_format: ResponseFormat,
}
/// One entry of the request's `messages` array.
#[derive(Serialize)]
struct ChatMessage {
    /// Wire role name, e.g. `"system"`, `"user"` or `"assistant"`.
    role: String,
    /// Message text.
    content: String,
}
/// The request's `response_format` object.
#[derive(Serialize)]
struct ResponseFormat {
    /// Format name; serialized under the JSON key `type` (a Rust keyword).
    #[serde(rename = "type")]
    kind: String,
}

/// The subset of a chat-completions response body that `step` reads; other fields are ignored.
#[derive(Deserialize)]
struct ChatResponse {
    /// Candidate replies; `step` uses only the first one.
    choices: Vec<Choice>,
}
/// One candidate reply.
#[derive(Deserialize)]
struct Choice {
    /// The assistant message for this candidate.
    message: ChoiceMessage,
}
/// The assistant message inside a [`Choice`].
#[derive(Deserialize)]
struct ChoiceMessage {
    /// Reply text; `step` hands it to `parse_model_output`.
    // NOTE(review): this is a plain `String`, so a response whose `content` is JSON `null`
    // fails to deserialize and surfaces as a parse error. Confirm the endpoint never returns
    // null content, or make this `Option<String>` together with the caller in `step`.
    content: String,
}

#[async_trait]
impl LocateModel for OpenAiModel {
    /// Runs one model turn.
    ///
    /// Sends `messages` to the configured chat-completions endpoint with temperature `0.0` and
    /// the `json_object` response format, then parses the first choice's content with
    /// `parse_model_output`.
    ///
    /// HTTP 5xx responses are retried up to `MAX_RETRIES` times. The backoff starts at 200 ms and
    /// doubles on each retry. HTTP 429 is not retried here; it is surfaced to the caller.
    ///
    /// # Errors
    /// - `LocateModelError::RateLimit` on HTTP 429. It carries the `retry-after` header
    ///   (whole seconds, default 2 when absent or unparsable) converted to milliseconds.
    /// - `LocateModelError::Provider` when:
    ///   - the request cannot be sent, or
    ///   - the response has any other non-success status, including a 5xx that persists after
    ///     all retries.
    ///
    ///   The API key is redacted from these messages.
    /// - `LocateModelError::Parse` when:
    ///   - the response body is malformed,
    ///   - it contains no choices, or
    ///   - the model's reply is not a valid step.
    async fn step(&self, messages: &[Message]) -> Result<ModelStep, LocateModelError> {
        let chat_msgs: Vec<ChatMessage> = messages
            .iter()
            .map(|m| ChatMessage {
                role: match m.role {
                    Role::System => "system",
                    Role::User => "user",
                    Role::Assistant => "assistant",
                    // Tool calls are emulated through JSON-mode replies, so no assistant message
                    // sent here ever carries `tool_calls`. The chat-completions API rejects a
                    // `tool`-role message that does not answer such a call (and it would also
                    // need a `tool_call_id`). Tool results therefore go back as user turns.
                    Role::Tool => "user",
                }
                .to_string(),
                content: m.content.clone(),
            })
            .collect();

        let body = ChatRequest {
            model: &self.model,
            messages: &chat_msgs,
            temperature: 0.0,
            response_format: ResponseFormat {
                kind: "json_object".into(),
            },
        };

        // Number of 5xx retries already performed.
        let mut attempt: u32 = 0;
        loop {
            let resp = self
                .client
                .post(&self.endpoint)
                .bearer_auth(self.api_key.expose())
                .json(&body)
                .send()
                .await
                .map_err(|e| LocateModelError::Provider(redact(&e.to_string(), &self.api_key)))?;

            let status = resp.status();
            if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
                // Only the delta-seconds form of Retry-After is understood; anything else
                // (e.g. an HTTP date) falls back to 2 seconds.
                let retry = resp
                    .headers()
                    .get("retry-after")
                    .and_then(|v| v.to_str().ok())
                    .and_then(|s| s.trim().parse().ok())
                    .unwrap_or(2);
                return Err(LocateModelError::RateLimit {
                    retry_after_ms: retry * 1000,
                });
            }
            if status.is_server_error() && attempt < MAX_RETRIES {
                attempt += 1;
                // Exponential backoff: 200 ms, 400 ms, 800 ms, ...
                let backoff_ms = 200u64 * (1u64 << (attempt - 1));
                tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
                continue;
            }
            if !status.is_success() {
                let text = resp.text().await.unwrap_or_default();
                let scrubbed = redact(&text, &self.api_key);
                return Err(LocateModelError::Provider(format!(
                    "HTTP {status}: {scrubbed}"
                )));
            }

            let parsed: ChatResponse = resp
                .json()
                .await
                .map_err(|e| LocateModelError::Parse(redact(&e.to_string(), &self.api_key)))?;
            let raw = parsed
                .choices
                .into_iter()
                .next()
                .ok_or_else(|| LocateModelError::Parse("no choices".into()))?
                .message
                .content;

            return parse_model_output(&raw);
        }
    }
}

/// Parses the model's raw reply into a [`ModelStep`].
///
/// `raw` must be a JSON object with one of these keys:
///
/// * `"tool": {"name": "...", "arguments": ...}` yields [`ModelStep::ToolCall`] with a freshly
///   generated `call_…` id. Missing or `null` `arguments` become `{}`.
/// * `"final": {"selected_node_ids": [...], "reasoning_summary": "..."}` yields
///   [`ModelStep::Final`]. Missing or `null` fields default to an empty list / empty string.
///
/// `tool` takes precedence when both are present. A key whose value is JSON `null` is treated as
/// absent. So a reply that fills in both branches of the schema, e.g.
/// `{"tool": null, "final": {...}}`, is read as a final answer.
///
/// # Errors
/// Returns [`LocateModelError::Parse`] when any of these hold:
/// - `raw` is not valid JSON;
/// - `tool.name` is missing or not a string;
/// - `selected_node_ids` is not an array of strings;
/// - neither `tool` nor `final` is present.
///
/// Some messages include the raw reply.
pub(crate) fn parse_model_output(raw: &str) -> Result<ModelStep, LocateModelError> {
    /// Looks up `key` in `obj`, treating an explicit JSON `null` the same as a missing key.
    fn present<'v>(obj: &'v serde_json::Value, key: &str) -> Option<&'v serde_json::Value> {
        obj.get(key).filter(|x| !x.is_null())
    }

    let v: serde_json::Value = serde_json::from_str(raw)
        .map_err(|e| LocateModelError::Parse(format!("not JSON: {e}: {raw}")))?;

    if let Some(t) = present(&v, "tool") {
        let name = t
            .get("name")
            .and_then(|x| x.as_str())
            .ok_or_else(|| LocateModelError::Parse("tool.name missing".into()))?
            .to_string();
        let arguments = present(t, "arguments")
            .cloned()
            .unwrap_or_else(|| serde_json::json!({}));
        return Ok(ModelStep::ToolCall {
            id: format!("call_{}", rand_id()),
            name,
            arguments,
        });
    }

    if let Some(f) = present(&v, "final") {
        let ids: Vec<String> = match present(f, "selected_node_ids") {
            Some(ids) => serde_json::from_value(ids.clone())
                .map_err(|e| LocateModelError::Parse(e.to_string()))?,
            None => Vec::new(),
        };
        let summary = f
            .get("reasoning_summary")
            .and_then(|x| x.as_str())
            .unwrap_or("")
            .to_string();
        return Ok(ModelStep::Final {
            selected_node_ids: ids,
            reasoning_summary: summary,
        });
    }

    Err(LocateModelError::Parse(format!(
        "expected `tool` or `final` key: {raw}"
    )))
}

/// Returns a process-unique identifier string, used for synthesized tool-call ids.
///
/// Despite the name, the value is not random. It has the form `"<nanos>-<seq>"`:
/// - `nanos` is the wall-clock time in nanoseconds since the Unix epoch (`0` if the clock reads
///   earlier than that);
/// - `seq` is a per-process counter.
///
/// The counter keeps ids distinct when the clock is coarse, steps backwards, or is pinned at `0`.
fn rand_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static NEXT_SEQ: AtomicU64 = AtomicU64::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(std::time::Duration::ZERO)
        .as_nanos();
    // Relaxed is enough: only the uniqueness of each fetched value matters, not its ordering
    // relative to other memory operations.
    let seq = NEXT_SEQ.fetch_add(1, Ordering::Relaxed);
    format!("{nanos}-{seq}")
}

// Unit tests for model-output parsing and API-key redaction.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Fake secret shared by the redaction tests.
    const FAKE_KEY: &str = "sk-secret-12345";

    #[test]
    fn parses_tool_call() {
        let input = r#"{"tool":{"name":"locate","arguments":{"query":"x"}}}"#;
        match parse_model_output(input).unwrap() {
            ModelStep::ToolCall { name, .. } => assert_eq!(name, "locate"),
            _ => panic!("expected ModelStep::ToolCall"),
        }
    }

    #[test]
    fn parses_final() {
        let input = r#"{"final":{"selected_node_ids":["n1"],"reasoning_summary":"r"}}"#;
        match parse_model_output(input).unwrap() {
            ModelStep::Final {
                selected_node_ids, ..
            } => assert_eq!(selected_node_ids, vec!["n1".to_string()]),
            _ => panic!("expected ModelStep::Final"),
        }
    }

    #[test]
    fn rejects_unknown_shape() {
        let outcome = parse_model_output(r#"{"hello":"world"}"#);
        assert!(outcome.is_err(), "an object without `tool`/`final` must be rejected");
    }

    #[test]
    fn api_key_redacts_in_display() {
        let key = ApiKey::new(FAKE_KEY);
        assert_eq!(key.to_string(), "***redacted***");
        assert_eq!(format!("{key:?}"), "ApiKey(***redacted***)");
    }

    #[test]
    fn redact_scrubs_key_from_text() {
        let key = ApiKey::new(FAKE_KEY);
        let scrubbed = redact("Bearer sk-secret-12345 not authorized", &key);
        assert_eq!(scrubbed, "Bearer ***redacted*** not authorized");
    }
}