//! menta 0.0.1 — minimal Rust library for non-UI LLM and AI primitives.
//!
//! This module implements the OpenAI provider (chat completions and
//! embeddings), registered with the provider inventory under the
//! `"openai"` id.
use reqwest::blocking::Client;
use serde_json::{json, Value};

use crate::{
    EmbeddingModel, EmbeddingResult, Error, FinishReason, LanguageModel, ModelMessage,
    ModelRequest, ModelResponse, Part, ProviderRegistration, Result, Role, ToolChoice,
    ToolDefinition, ToolSchema, Usage,
};

/// Language model backed by the OpenAI chat-completions endpoint.
#[derive(Clone, Debug)]
pub struct OpenAiLanguageModel {
    // Sent verbatim as the `model` field of API requests.
    model_id: String,
}

/// Embedding model backed by the OpenAI embeddings endpoint.
#[derive(Clone, Debug)]
pub struct OpenAiEmbeddingModel {
    // Sent verbatim as the `model` field of API requests.
    model_id: String,
}

impl LanguageModel for OpenAiLanguageModel {
    fn model_id(&self) -> &str {
        &self.model_id
    }

    fn generate(&self, request: &ModelRequest) -> Result<ModelResponse> {
        let (status, body) = openai_post_json(
            "https://api.openai.com/v1/chat/completions",
            openai_chat_request(&self.model_id, request),
        )?;

        if !(200..300).contains(&status) {
            return Err(Error::Api(openai_error_message(&body)));
        }

        openai_chat_response_to_model_response(&self.model_id, &body)
    }
}

impl EmbeddingModel for OpenAiEmbeddingModel {
    fn model_id(&self) -> &str {
        &self.model_id
    }

    fn embed(&self, value: &str) -> Result<EmbeddingResult> {
        let (status, body) = openai_post_json(
            "https://api.openai.com/v1/embeddings",
            json!({
                "model": self.model_id,
                "input": value,
            }),
        )?;

        if !(200..300).contains(&status) {
            return Err(Error::Api(openai_error_message(&body)));
        }

        let embedding = body
            .get("data")
            .and_then(Value::as_array)
            .and_then(|items| items.first())
            .and_then(|item| item.get("embedding"))
            .and_then(Value::as_array)
            .ok_or_else(|| Error::Parse("missing embedding data".to_string()))?
            .iter()
            .map(|value| {
                value
                    .as_f64()
                    .map(|value| value as f32)
                    .ok_or_else(|| Error::Parse("invalid embedding value".to_string()))
            })
            .collect::<Result<Vec<_>>>()?;

        Ok(EmbeddingResult {
            embedding,
            usage: openai_usage(&body),
        })
    }
}

fn openai_language_model(model_id: &str) -> Result<Box<dyn LanguageModel>> {
    if model_id.is_empty() {
        return Err(Error::UnsupportedModel("openai/".to_string()));
    }

    Ok(Box::new(OpenAiLanguageModel {
        model_id: model_id.to_string(),
    }))
}

/// Factory for OpenAI embedding models, used by the provider registration.
///
/// # Errors
///
/// Returns [`Error::UnsupportedModel`] when `model_id` is empty.
fn openai_embedding_model(model_id: &str) -> Result<Box<dyn EmbeddingModel>> {
    if model_id.is_empty() {
        return Err(Error::UnsupportedModel("openai/".to_string()));
    }

    // The unsizing coercion to `Box<dyn EmbeddingModel>` is driven by the
    // return type; the explicit `as` cast here was redundant and
    // inconsistent with `openai_language_model`, so it is removed.
    Ok(Box::new(OpenAiEmbeddingModel {
        model_id: model_id.to_string(),
    }))
}

// Registers the OpenAI factories with the global provider inventory under
// the "openai" id — presumably resolved from "openai/<model>" strings,
// judging by the "openai/" text in the factories' error messages
// (TODO confirm against the registry lookup code).
inventory::submit! {
    ProviderRegistration {
        id: "openai",
        language_model: openai_language_model,
        embedding_model: openai_embedding_model,
    }
}

fn openai_api_key() -> Result<String> {
    std::env::var("OPENAI_API_KEY").map_err(|_| Error::MissingEnvironmentVariable("OPENAI_API_KEY"))
}

/// POSTs `body` as JSON to `url` with bearer auth from `OPENAI_API_KEY`,
/// returning the HTTP status code together with the parsed JSON body.
///
/// # Errors
///
/// Returns [`Error::MissingEnvironmentVariable`] when the key is unset,
/// [`Error::Http`] for client-build/transport failures or a panicked
/// worker thread, and [`Error::Json`] when the response is not valid JSON.
fn openai_post_json(url: &'static str, body: Value) -> Result<(u16, Value)> {
    let api_key = openai_api_key()?;

    // NOTE(review): the request is run on a dedicated thread that is
    // joined immediately. This looks like the standard workaround for
    // reqwest's blocking `Client`, which panics when used from within an
    // async runtime — TODO confirm; if no caller is ever async, the
    // spawn/join pair is pure overhead and could be removed.
    std::thread::spawn(move || {
        let response = Client::builder()
            .build()
            .map_err(|error| Error::Http(error.to_string()))?
            .post(url)
            .bearer_auth(api_key)
            .json(&body)
            .send()
            .map_err(|error| Error::Http(error.to_string()))?;

        let status = response.status().as_u16();
        // Note: error responses are parsed as JSON too, so callers can
        // extract the API error message from the body.
        let body = response
            .json()
            .map_err(|error| Error::Json(error.to_string()))?;

        Ok((status, body))
    })
    .join()
    // A panic in the worker is surfaced as an HTTP-layer error rather
    // than propagating the panic into the caller's thread.
    .map_err(|_| Error::Http("openai request thread panicked".to_string()))?
}

/// Builds the chat-completions request payload: model, converted
/// messages, and any optional settings / tool configuration present on
/// `request`.
fn openai_chat_request(model_id: &str, request: &ModelRequest) -> Value {
    let mut payload = json!({
        "model": model_id,
        "messages": openai_messages(&request.messages),
    });

    let settings = &request.settings;
    if let Some(temperature) = settings.temperature {
        payload["temperature"] = json!(temperature);
    }
    if let Some(limit) = settings.max_output_tokens {
        payload["max_tokens"] = json!(limit);
    }

    // Tool fields are emitted only when at least one tool is defined.
    if !request.tools.is_empty() {
        let tools: Vec<Value> = request.tools.iter().map(openai_tool_definition).collect();
        payload["tools"] = Value::Array(tools);
        payload["tool_choice"] = openai_tool_choice(&request.tool_choice);
    }

    payload
}

fn openai_messages(messages: &[ModelMessage]) -> Vec<Value> {
    messages
        .iter()
        .map(|message| match message.role {
            Role::System => json!({ "role": "system", "content": message.text() }),
            Role::User => json!({ "role": "user", "content": message.text() }),
            Role::Assistant => {
                let content = message.text();
                let tool_calls = message
                    .parts
                    .iter()
                    .filter_map(|part| match part {
                        Part::ToolCall(call) => Some(json!({
                            "id": call.id,
                            "type": "function",
                            "function": {
                                "name": call.name,
                                "arguments": call.input,
                            }
                        })),
                        _ => None,
                    })
                    .collect::<Vec<_>>();

                if tool_calls.is_empty() {
                    json!({ "role": "assistant", "content": content })
                } else {
                    json!({
                        "role": "assistant",
                        "content": if content.is_empty() { Value::Null } else { Value::String(content) },
                        "tool_calls": tool_calls,
                    })
                }
            }
            Role::Tool => {
                let result = message.parts.iter().find_map(|part| match part {
                    Part::ToolResult(result) => Some(result),
                    _ => None,
                });

                match result {
                    Some(result) => json!({
                        "role": "tool",
                        "tool_call_id": result.call_id,
                        "content": result.output,
                    }),
                    None => json!({ "role": "tool", "content": message.text() }),
                }
            }
        })
        .collect()
}

/// Wraps a tool definition in OpenAI's `{"type": "function", ...}` shape.
fn openai_tool_definition(tool: &ToolDefinition) -> Value {
    let function = json!({
        "name": tool.name,
        "description": tool.description,
        "parameters": tool_schema_json(&tool.input_schema),
    });

    json!({
        "type": "function",
        "function": function,
    })
}

fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
    match tool_choice {
        ToolChoice::Auto => json!("auto"),
        ToolChoice::None => json!("none"),
        ToolChoice::Required(name) => json!({
            "type": "function",
            "function": { "name": name }
        }),
    }
}

/// Converts a [`ToolSchema`] into the JSON Schema form OpenAI accepts for
/// tool `parameters`. Object schemas always set
/// `additionalProperties: false`.
pub(crate) fn tool_schema_json(schema: &ToolSchema) -> Value {
    match schema {
        ToolSchema::String { description } => scalar_schema("string", description),
        ToolSchema::Integer { description } => scalar_schema("integer", description),
        ToolSchema::Number { description } => scalar_schema("number", description),
        ToolSchema::Boolean { description } => scalar_schema("boolean", description),
        ToolSchema::Array { description, items } => json_with_description(
            json!({ "type": "array", "items": tool_schema_json(items) }),
            description,
        ),
        ToolSchema::Object(object) => {
            // One pass over the fields collects both the property schemas
            // and the list of required names.
            let mut properties = serde_json::Map::new();
            let mut required = Vec::new();
            for field in &object.fields {
                let mut field_schema = tool_schema_json(&field.schema);
                if let Some(text) = &field.description {
                    field_schema["description"] = json!(text);
                }
                if field.required {
                    required.push(Value::String(field.name.clone()));
                }
                properties.insert(field.name.clone(), field_schema);
            }

            json_with_description(
                json!({
                    "type": "object",
                    "properties": properties,
                    "required": required,
                    "additionalProperties": false,
                }),
                &object.description,
            )
        }
    }
}

/// Builds a `{"type": <name>}` schema with an optional description.
fn scalar_schema(type_name: &str, description: &Option<String>) -> Value {
    json_with_description(json!({ "type": type_name }), description)
}

/// Returns `value` with a `description` field attached when one is given.
fn json_with_description(value: Value, description: &Option<String>) -> Value {
    match description {
        None => value,
        Some(text) => {
            let mut annotated = value;
            annotated["description"] = json!(text);
            annotated
        }
    }
}

fn openai_chat_response_to_model_response(model_id: &str, body: &Value) -> Result<ModelResponse> {
    let choice = body
        .get("choices")
        .and_then(Value::as_array)
        .and_then(|choices| choices.first())
        .ok_or_else(|| Error::Parse("missing choice".to_string()))?;
    let message = choice
        .get("message")
        .ok_or_else(|| Error::Parse("missing message".to_string()))?;

    let mut parts = Vec::new();

    if let Some(content) = message.get("content") {
        let text = openai_text_content(content);
        if !text.is_empty() {
            parts.push(Part::Text(text));
        }
    }

    if let Some(tool_calls) = message.get("tool_calls").and_then(Value::as_array) {
        for tool_call in tool_calls {
            let id = tool_call
                .get("id")
                .and_then(Value::as_str)
                .ok_or_else(|| Error::Parse("missing tool call id".to_string()))?;
            let function = tool_call
                .get("function")
                .ok_or_else(|| Error::Parse("missing tool call function".to_string()))?;
            let name = function
                .get("name")
                .and_then(Value::as_str)
                .ok_or_else(|| Error::Parse("missing tool call name".to_string()))?;
            let input = function
                .get("arguments")
                .and_then(Value::as_str)
                .ok_or_else(|| Error::Parse("missing tool call arguments".to_string()))?;

            parts.push(Part::ToolCall(crate::ToolCall {
                id: id.to_string(),
                name: name.to_string(),
                input: input.to_string(),
            }));
        }
    }

    Ok(ModelResponse {
        parts,
        finish_reason: openai_finish_reason(choice.get("finish_reason").and_then(Value::as_str)),
        usage: openai_usage(body),
        response_metadata: crate::metadata_with_provider("openai", model_id),
    })
}

/// Flattens message content — either a plain string or an array of parts
/// carrying `"text"` fields — into a single `String`. Anything else
/// (including null) yields an empty string.
fn openai_text_content(content: &Value) -> String {
    if let Some(text) = content.as_str() {
        return text.to_string();
    }

    let mut combined = String::new();
    if let Some(parts) = content.as_array() {
        for part in parts {
            if let Some(text) = part.get("text").and_then(Value::as_str) {
                combined.push_str(text);
            }
        }
    }
    combined
}

fn openai_finish_reason(reason: Option<&str>) -> FinishReason {
    match reason {
        Some("tool_calls") => FinishReason::ToolCalls,
        Some("length") => FinishReason::Length,
        Some("stop") | None => FinishReason::Stop,
        _ => FinishReason::Error,
    }
}

/// Reads token counts from the response's `usage` object; a missing
/// object or field yields zero.
pub(crate) fn openai_usage(body: &Value) -> Usage {
    let count = |key: &str| -> usize {
        body.get("usage")
            .and_then(|usage| usage.get(key))
            .and_then(Value::as_u64)
            .unwrap_or(0) as usize
    };

    Usage {
        input_tokens: count("prompt_tokens"),
        output_tokens: count("completion_tokens"),
    }
}

/// Pulls `error.message` out of an OpenAI error body, with a fixed
/// fallback when the field is absent or not a string.
fn openai_error_message(body: &Value) -> String {
    let message = body
        .get("error")
        .and_then(|error| error.get("message"))
        .and_then(Value::as_str);

    match message {
        Some(text) => text.to_string(),
        None => "unknown OpenAI error".to_string(),
    }
}