//! systemprompt-ai 0.1.22
//!
//! Core AI module for systemprompt.io — OpenAI chat-completions provider:
//! plain, tool-calling, structured, and schema-constrained generation.
use anyhow::{Result, anyhow};
use serde_json::json;
use std::time::Instant;
use tracing::warn;
use uuid::Uuid;

use crate::models::ai::AiResponse;
use crate::models::providers::openai::{
    OpenAiJsonSchema, OpenAiRequest, OpenAiResponse, OpenAiResponseFormat,
};
use crate::models::tools::ToolCall;
use crate::services::providers::{
    GenerationParams, SchemaGenerationParams, StructuredGenerationParams, ToolGenerationParams,
};
use systemprompt_identifiers::AiToolCallId;

use super::provider::OpenAiProvider;
use super::response_builder::build_response;
use super::{converters, reasoning};

pub async fn generate(
    provider: &OpenAiProvider,
    params: GenerationParams<'_>,
) -> Result<AiResponse> {
    let start = Instant::now();
    let request_id = Uuid::new_v4();

    let openai_messages: Vec<crate::models::providers::openai::OpenAiMessage> =
        params.messages.iter().map(Into::into).collect();

    let (temperature, top_p, presence_penalty, frequency_penalty) =
        params.sampling.map_or((None, None, None, None), |s| {
            (
                s.temperature,
                s.top_p,
                s.presence_penalty,
                s.frequency_penalty,
            )
        });

    let reasoning_config = reasoning::build_reasoning_config(params.model);

    let request = OpenAiRequest {
        model: params.model.to_string(),
        messages: openai_messages,
        temperature,
        top_p,
        presence_penalty,
        frequency_penalty,
        max_tokens: Some(params.max_output_tokens),
        tools: None,
        response_format: None,
        reasoning_effort: reasoning_config,
    };

    let response = provider
        .client
        .post(format!("{}/chat/completions", provider.endpoint))
        .bearer_auth(&provider.api_key)
        .json(&request)
        .send()
        .await?;

    if !response.status().is_success() {
        let error_text = response.text().await?;
        return Err(anyhow!("OpenAI API error: {error_text}"));
    }

    let openai_response: OpenAiResponse = response.json().await?;
    build_response(request_id, &openai_response, "openai", params.model, start)
}

pub async fn generate_with_tools(
    provider: &OpenAiProvider,
    params: ToolGenerationParams<'_>,
) -> Result<(AiResponse, Vec<ToolCall>)> {
    let start = Instant::now();
    let request_id = Uuid::new_v4();

    let openai_messages: Vec<crate::models::providers::openai::OpenAiMessage> =
        params.base.messages.iter().map(Into::into).collect();

    let openai_tools = converters::convert_tools(params.tools)?;

    let (temperature, top_p, presence_penalty, frequency_penalty) =
        params.base.sampling.map_or((None, None, None, None), |s| {
            (
                s.temperature,
                s.top_p,
                s.presence_penalty,
                s.frequency_penalty,
            )
        });

    let reasoning_config = reasoning::build_reasoning_config(params.base.model);

    let request = OpenAiRequest {
        model: params.base.model.to_string(),
        messages: openai_messages,
        temperature,
        top_p,
        presence_penalty,
        frequency_penalty,
        max_tokens: Some(params.base.max_output_tokens),
        tools: Some(openai_tools),
        response_format: None,
        reasoning_effort: reasoning_config,
    };

    let response = provider
        .client
        .post(format!("{}/chat/completions", provider.endpoint))
        .bearer_auth(&provider.api_key)
        .json(&request)
        .send()
        .await?;

    if !response.status().is_success() {
        let error_text = response.text().await?;
        return Err(anyhow!("OpenAI API error: {error_text}"));
    }

    let openai_response: OpenAiResponse = response.json().await?;

    let choice = openai_response
        .choices
        .first()
        .ok_or_else(|| anyhow!("No response from OpenAI"))?;

    let tool_calls = choice
        .message
        .tool_calls
        .clone()
        .unwrap_or_else(Vec::new)
        .into_iter()
        .map(|tc| {
            let arguments = serde_json::from_str::<serde_json::Value>(&tc.function.arguments)
                .unwrap_or_else(|e| {
                    warn!(
                        error = %e,
                        tool_name = %tc.function.name,
                        raw_arguments = %tc.function.arguments,
                        "Failed to parse OpenAI tool arguments"
                    );
                    json!({})
                });

            ToolCall {
                ai_tool_call_id: AiToolCallId::from(tc.id),
                name: tc.function.name,
                arguments,
            }
        })
        .collect();

    let ai_response = build_response(
        request_id,
        &openai_response,
        "openai",
        params.base.model,
        start,
    )?;
    Ok((ai_response, tool_calls))
}

pub async fn generate_structured(
    provider: &OpenAiProvider,
    params: StructuredGenerationParams<'_>,
) -> Result<AiResponse> {
    let start = Instant::now();
    let request_id = Uuid::new_v4();

    let openai_messages: Vec<crate::models::providers::openai::OpenAiMessage> =
        params.base.messages.iter().map(Into::into).collect();

    let (temperature, top_p, presence_penalty, frequency_penalty) =
        params.base.sampling.map_or((None, None, None, None), |s| {
            (
                s.temperature,
                s.top_p,
                s.presence_penalty,
                s.frequency_penalty,
            )
        });

    let reasoning_config = reasoning::build_reasoning_config(params.base.model);

    let request = OpenAiRequest {
        model: params.base.model.to_string(),
        messages: openai_messages,
        temperature,
        top_p,
        presence_penalty,
        frequency_penalty,
        max_tokens: Some(params.base.max_output_tokens),
        tools: None,
        response_format: converters::convert_response_format(params.response_format)?,
        reasoning_effort: reasoning_config,
    };

    let response = provider
        .client
        .post(format!("{}/chat/completions", provider.endpoint))
        .bearer_auth(&provider.api_key)
        .json(&request)
        .send()
        .await?;

    if !response.status().is_success() {
        let error_text = response.text().await?;
        return Err(anyhow!("OpenAI API error: {error_text}"));
    }

    let openai_response: OpenAiResponse = response.json().await?;
    build_response(
        request_id,
        &openai_response,
        "openai",
        params.base.model,
        start,
    )
}

pub async fn generate_with_schema(
    provider: &OpenAiProvider,
    params: SchemaGenerationParams<'_>,
) -> Result<AiResponse> {
    let start = Instant::now();
    let request_id = Uuid::new_v4();

    let openai_messages: Vec<crate::models::providers::openai::OpenAiMessage> =
        params.base.messages.iter().map(Into::into).collect();

    let (temperature, top_p) = params
        .base
        .sampling
        .map_or((None, None), |s| (s.temperature, s.top_p));

    let reasoning_config = reasoning::build_reasoning_config(params.base.model);

    let request = OpenAiRequest {
        model: params.base.model.to_string(),
        messages: openai_messages,
        temperature,
        top_p,
        presence_penalty: None,
        frequency_penalty: None,
        max_tokens: Some(params.base.max_output_tokens),
        tools: None,
        response_format: Some(OpenAiResponseFormat::JsonSchema {
            json_schema: OpenAiJsonSchema {
                name: "structured_output".to_string(),
                schema: params.response_schema,
                strict: Some(true),
            },
        }),
        reasoning_effort: reasoning_config,
    };

    let response = provider
        .client
        .post(format!("{}/chat/completions", provider.endpoint))
        .header("Authorization", format!("Bearer {}", provider.api_key))
        .json(&request)
        .send()
        .await?;

    if !response.status().is_success() {
        let error_text = response.text().await?;
        return Err(anyhow!("OpenAI API error: {error_text}"));
    }

    let openai_response: OpenAiResponse = response.json().await?;
    build_response(
        request_id,
        &openai_response,
        "openai",
        params.base.model,
        start,
    )
}