llmkit-openai 0.1.0

OpenAI (GPT-4o, o1) provider adapter for llmkit-rs
Documentation
//! Mapping between llmkit types and OpenAI `/v1/chat/completions` wire types.

use llmkit_core::{
    ChatRequest, ChatResponse, FinishReason, LlmError, LlmResult, Message, MessageContent, Role,
    ToolCall, ToolChoice,
};
use serde_json::json;

use crate::types::*;

pub(crate) fn build_request(
    req: &ChatRequest,
    model: String,
    stream: bool,
) -> ChatCompletionRequest {
    let mut messages = Vec::with_capacity(req.messages.len() + 1);
    if let Some(system) = &req.system {
        messages.push(WireMessage {
            role: "system".into(),
            content: Some(system.clone()),
            tool_calls: None,
            tool_call_id: None,
        });
    }
    for m in &req.messages {
        messages.push(map_message(m));
    }

    let tools = req.tools.as_ref().map(|ts| {
        ts.iter()
            .map(|t| WireTool {
                kind: "function",
                function: WireFunction {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters: t.input_schema.clone(),
                },
            })
            .collect()
    });

    let tool_choice = req.tool_choice.as_ref().map(map_tool_choice);

    ChatCompletionRequest {
        model,
        messages,
        max_tokens: req.max_tokens,
        temperature: req.temperature,
        stop: req.stop.clone(),
        tools,
        tool_choice,
        stream,
        stream_options: stream.then_some(StreamOptions { include_usage: true }),
    }
}

/// Converts a single llmkit message into its wire representation.
///
/// * Tool results become `"tool"` messages carrying the originating call id, regardless of
///   the role recorded on the message.
/// * Tool uses become `"assistant"` messages with no content and exactly one function call
///   whose arguments are the JSON-serialized `input`.
/// * Any other content keeps the message's own role (via `role_str`) and sends its text,
///   or an empty string when `as_text` yields nothing.
fn map_message(m: &Message) -> WireMessage {
    match &m.content {
        MessageContent::ToolResult { tool_use_id, content } => WireMessage {
            role: "tool".into(),
            content: Some(content.clone()),
            tool_calls: None,
            tool_call_id: Some(tool_use_id.clone()),
        },
        MessageContent::ToolUse { id, name, input } => {
            let call = WireToolCall {
                id: id.clone(),
                kind: "function",
                function: WireToolCallFunction {
                    name: name.clone(),
                    arguments: input.to_string(),
                },
            };
            WireMessage {
                role: "assistant".into(),
                content: None,
                tool_calls: Some(vec![call]),
                tool_call_id: None,
            }
        }
        other => WireMessage {
            role: role_str(m.role).into(),
            // Content is always sent for plain messages, even if it is empty.
            content: Some(other.as_text().unwrap_or_default()),
            tool_calls: None,
            tool_call_id: None,
        },
    }
}

/// Returns the wire-format role name for `role`.
///
/// The match is intentionally exhaustive so that adding a `Role` variant forces this
/// mapping to be revisited.
fn role_str(role: Role) -> &'static str {
    match role {
        Role::System => "system",
        Role::User => "user",
        Role::Assistant => "assistant",
        Role::Tool => "tool",
    }
}

/// Converts a tool-choice setting into the JSON value sent as `tool_choice`.
///
/// `Auto`, `Any` and `None` map to the plain strings `"auto"`, `"required"` and `"none"`;
/// a specific tool maps to a `{"type": "function", "function": {"name": ...}}` object.
fn map_tool_choice(choice: &ToolChoice) -> serde_json::Value {
    let keyword = match choice {
        // A named tool is the only variant that needs an object rather than a keyword.
        ToolChoice::Tool(name) => {
            return json!({ "type": "function", "function": { "name": name } });
        }
        ToolChoice::Auto => "auto",
        ToolChoice::Any => "required",
        ToolChoice::None => "none",
    };
    serde_json::Value::from(keyword)
}

/// Translates a wire `finish_reason` string into a [`FinishReason`].
///
/// Recognized values are `"stop"`, `"length"`, `"tool_calls"` / `"function_call"` and
/// `"content_filter"`; any other string is preserved in `FinishReason::Other`. When no
/// reason is given, the result is `ToolUse` if `has_tools` is set and `Stop` otherwise.
pub(crate) fn map_finish_reason(reason: Option<&str>, has_tools: bool) -> FinishReason {
    match reason {
        None => {
            // No explicit reason: infer it from whether tool calls were produced.
            if has_tools {
                FinishReason::ToolUse
            } else {
                FinishReason::Stop
            }
        }
        Some(label) => match label {
            "stop" => FinishReason::Stop,
            "length" => FinishReason::MaxTokens,
            "tool_calls" | "function_call" => FinishReason::ToolUse,
            "content_filter" => FinishReason::ContentFilter,
            unrecognized => FinishReason::Other(unrecognized.to_string()),
        },
    }
}

pub(crate) fn map_response(
    resp: ChatCompletionResponse,
    latency_ms: u64,
) -> LlmResult<ChatResponse> {
    let choice = resp
        .choices
        .into_iter()
        .next()
        .ok_or_else(|| LlmError::Provider { status: 200, message: "no choices returned".into() })?;

    let mut tool_calls = Vec::new();
    if let Some(calls) = choice.message.tool_calls {
        for c in calls {
            let input = serde_json::from_str(&c.function.arguments)
                .unwrap_or_else(|_| json!(c.function.arguments));
            tool_calls.push(ToolCall::new(c.id, c.function.name, input));
        }
    }

    let text = choice.message.content.unwrap_or_default();
    let finish_reason = map_finish_reason(choice.finish_reason.as_deref(), !tool_calls.is_empty());
    let usage = resp
        .usage
        .map(|u| llmkit_core::TokenUsage::new(u.prompt_tokens, u.completion_tokens))
        .unwrap_or_default();

    Ok(ChatResponse {
        id: resp.id,
        provider: "openai".into(),
        model: resp.model,
        message: Message::assistant(text),
        finish_reason,
        tool_calls,
        usage,
        cost: None,
        latency_ms,
    })
}