//! systemprompt-ai 0.1.19
//!
//! Core AI module for systemprompt.io.
//! Documentation
use anyhow::Result;
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;

use crate::models::ai::{AiResponse, SamplingParams, SearchGroundedResponse, StreamChunk};
use crate::models::tools::ToolCall;
use crate::services::providers::{
    AiProvider, GenerationParams, ModelPricing, SchemaGenerationParams, SearchGenerationParams,
    StructuredGenerationParams, ToolGenerationParams,
};
use crate::services::schema::ProviderCapabilities;

use super::provider::OpenAiProvider;
use super::{converters, generation, search};

#[async_trait]
impl AiProvider for OpenAiProvider {
    /// Stable provider identifier used for routing, logging, and config lookup.
    fn name(&self) -> &'static str {
        "openai"
    }

    /// Enables downcasting back to `OpenAiProvider` from a `dyn AiProvider`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Capability set shared by all OpenAI models (delegated to the schema layer).
    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities::openai()
    }

    /// Returns `true` only for the exact model identifiers listed here.
    ///
    /// NOTE(review): dated snapshot names (e.g. "gpt-4o-mini-2024-07-18") are
    /// priced by `get_pricing` but rejected here — confirm whether snapshot
    /// identifiers should also be accepted.
    fn supports_model(&self, model: &str) -> bool {
        matches!(
            model,
            "gpt-4.1"
                | "gpt-4.1-mini"
                | "gpt-4.1-nano"
                | "o4-mini"
                | "o3"
                | "o3-mini"
                | "gpt-4o"
                | "gpt-4o-mini"
                | "gpt-4-turbo"
                | "gpt-4"
                | "gpt-3.5-turbo"
                | "o1"
                | "o1-mini"
                | "o1-preview"
        )
    }

    /// Sampling parameters are always accepted; unsupported fields are handled
    /// downstream by the request converters.
    fn supports_sampling(&self, _sampling: Option<&SamplingParams>) -> bool {
        true
    }

    /// Model used when the caller does not specify one.
    fn default_model(&self) -> &'static str {
        "gpt-4.1"
    }

    /// Per-model pricing in USD per 1K tokens as `(input, output)`.
    ///
    /// Every identifier accepted by `supports_model` has an explicit arm so a
    /// supported model can never silently fall through to the generic default.
    fn get_pricing(&self, model: &str) -> ModelPricing {
        match model {
            "gpt-4.1" => ModelPricing::new(0.002, 0.008),
            "gpt-4.1-mini" => ModelPricing::new(0.0004, 0.0016),
            "gpt-4.1-nano" => ModelPricing::new(0.0001, 0.0004),
            "o4-mini" | "o3-mini" => ModelPricing::new(0.0011, 0.0044),
            "o3" => ModelPricing::new(0.01, 0.04),
            // gpt-4o ($2.50 / $10 per 1M) previously relied on the catch-all.
            "gpt-4o" | "gpt-4o-2024-08-06" => ModelPricing::new(0.0025, 0.01),
            "gpt-4o-mini" | "gpt-4o-mini-2024-07-18" => ModelPricing::new(0.00015, 0.0006),
            "gpt-4" | "gpt-4-turbo" | "gpt-4-turbo-preview" => ModelPricing::new(0.01, 0.03),
            "gpt-3.5-turbo" | "gpt-3.5-turbo-0125" => ModelPricing::new(0.0005, 0.0015),
            // o1-preview shares o1 pricing ($15 / $60 per 1M); it previously
            // fell through to the default and was undercharged ~6x.
            "o1" | "o1-preview" | "o1-2024-12-17" => ModelPricing::new(0.015, 0.06),
            "o1-mini" | "o1-mini-2024-09-12" => ModelPricing::new(0.003, 0.012),
            // Unknown models default to gpt-4o-tier pricing.
            _ => ModelPricing::new(0.0025, 0.01),
        }
    }

    /// Plain text generation; delegated to the shared generation module.
    async fn generate(&self, params: GenerationParams<'_>) -> Result<AiResponse> {
        generation::generate(self, params).await
    }

    /// Generation with tool/function calling; returns the response plus any
    /// tool calls the model emitted.
    async fn generate_with_tools(
        &self,
        params: ToolGenerationParams<'_>,
    ) -> Result<(AiResponse, Vec<ToolCall>)> {
        generation::generate_with_tools(self, params).await
    }

    /// Generation constrained to structured (JSON) output.
    async fn generate_structured(
        &self,
        params: StructuredGenerationParams<'_>,
    ) -> Result<AiResponse> {
        generation::generate_structured(self, params).await
    }

    /// Generation constrained to a caller-provided JSON schema.
    async fn generate_with_schema(&self, params: SchemaGenerationParams<'_>) -> Result<AiResponse> {
        generation::generate_with_schema(self, params).await
    }

    fn supports_json_mode(&self) -> bool {
        true
    }

    fn supports_structured_output(&self) -> bool {
        true
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    /// Streaming variant of [`Self::generate`]; no tools attached.
    async fn generate_stream(
        &self,
        params: GenerationParams<'_>,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>> {
        self.create_stream_request(params, None).await
    }

    /// Streaming variant of [`Self::generate_with_tools`]: tools are converted
    /// to the OpenAI wire format before the stream request is issued.
    async fn generate_with_tools_stream(
        &self,
        params: ToolGenerationParams<'_>,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>> {
        let openai_tools = converters::convert_tools(params.tools)?;
        self.create_stream_request(params.base, Some(openai_tools))
            .await
    }

    /// Web search availability is a per-instance runtime flag, not a model trait.
    fn supports_google_search(&self) -> bool {
        self.web_search_enabled
    }

    /// Search-grounded generation via OpenAI web search. Sampling parameters
    /// are attached only when the caller supplied them.
    async fn generate_with_google_search(
        &self,
        params: SearchGenerationParams<'_>,
    ) -> Result<SearchGroundedResponse> {
        let search_params = search::SearchParams::new(
            params.base.messages,
            params.base.max_output_tokens,
            params.base.model,
        );
        let search_params = if let Some(sampling) = params.base.sampling {
            search_params.with_sampling(sampling)
        } else {
            search_params
        };
        search::generate_with_web_search(self, search_params).await
    }
}