// systemprompt-ai 0.1.19
//
// Core AI module for systemprompt.io
// Documentation
use anyhow::Result;
use uuid::Uuid;

use crate::models::RequestStatus;
use crate::models::ai::{AiMessage, AiRequest, AiResponse, GenerateResponseParams};
use crate::models::tools::McpTool;
use crate::services::providers::{GenerationParams, ModelPricing, ToolGenerationParams};

use super::super::request_storage::StoreParams;
use super::service::AiService;

impl AiService {
    /// Runs the planning phase: sends the request together with the available
    /// tools to the provider and classifies the outcome.
    ///
    /// A fresh request id is generated and the provider call is timed. On
    /// success the response is stamped with that id, the measured latency and
    /// the returned tool calls, its cost is computed with `estimate_cost`, and
    /// it is handed to `self.storage.store` with `RequestStatus::Completed`.
    ///
    /// # Returns
    ///
    /// * `PlanningResult::DirectResponse` carrying the response content when
    ///   the provider returned no tool calls.
    /// * `PlanningResult::ToolCalls` otherwise, with the response content as
    ///   `reasoning` and one `PlannedToolCall` per returned tool call.
    ///
    /// # Errors
    ///
    /// Returns an error if the provider cannot be resolved or if
    /// `generate_with_tools` fails. Only the latter is passed to `store_error`
    /// before being returned; a provider-lookup failure returns immediately.
    pub async fn generate_plan(
        &self,
        request: &AiRequest,
        available_tools: &[McpTool],
    ) -> Result<systemprompt_models::ai::PlanningResult> {
        let request_id = Uuid::new_v4();
        let start = std::time::Instant::now();
        let provider = self.get_provider(request.provider())?;
        let model = request.model();

        let base = GenerationParams::new(&request.messages, model, request.max_output_tokens());
        // `base` is owned, so it can be moved into whichever arm runs; no clone
        // is needed to apply the optional sampling settings.
        let base = match request.sampling.as_ref() {
            Some(sampling) => base.with_sampling(sampling),
            None => base,
        };
        let params = ToolGenerationParams::new(base, available_tools.to_vec());

        let result = provider.generate_with_tools(params).await;

        // `as_millis` yields a u128; saturate instead of silently truncating.
        let latency_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);

        match result {
            Ok((mut response, tool_calls)) => {
                response.request_id = request_id;
                response.latency_ms = latency_ms;
                // Copy the calls onto the response for storage; the owned
                // `tool_calls` vector is still consumed below.
                response.tool_calls.clone_from(&tool_calls);
                let cost = self.estimate_cost(&response);
                self.storage.store(&StoreParams {
                    request,
                    response: &response,
                    context: &request.context,
                    status: RequestStatus::Completed,
                    error_message: None,
                    cost_microdollars: cost,
                });

                Ok(if tool_calls.is_empty() {
                    systemprompt_models::ai::PlanningResult::DirectResponse {
                        content: response.content,
                    }
                } else {
                    systemprompt_models::ai::PlanningResult::ToolCalls {
                        reasoning: response.content,
                        calls: tool_calls
                            .into_iter()
                            .map(|tc| systemprompt_models::ai::PlannedToolCall {
                                tool_name: tc.name,
                                arguments: tc.arguments,
                            })
                            .collect(),
                    }
                })
            },
            Err(e) => {
                self.store_error(request, request_id, latency_ms, &e);
                Err(e)
            },
        }
    }

    /// Produces the final text reply once tool execution has finished.
    ///
    /// A user message containing `params.execution_summary` and the
    /// response-phase instructions is appended to `params.messages`. Provider,
    /// model and max output tokens are each resolved in this order of
    /// precedence: the context's `tool_model_config()` value, then the value
    /// supplied in `params`, then the service default. The resulting request
    /// is passed to `self.generate`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `self.generate`.
    pub async fn generate_response(&self, params: GenerateResponseParams<'_>) -> Result<String> {
        let followup = AiMessage::user(format!(
            "## Tool Execution Complete\n\nThe following tools have been executed:\n\n{}\n\n## \
             Response Phase Instructions\n\nThis is the RESPONSE PHASE - your task is to \
             synthesize results and respond to the user.\n\n**CRITICAL: Do NOT attempt to call \
             any tools.** Tools are not available in this phase.\nAnalyze the tool results and \
             provide a helpful response to the user.",
            params.execution_summary
        ));
        let mut conversation = params.messages;
        conversation.push(followup);

        let tool_config = params.context.tool_model_config();

        // Each setting falls back: tool config -> caller params -> service default.
        let provider = tool_config
            .and_then(|cfg| cfg.provider.as_deref())
            .or(params.provider)
            .unwrap_or_else(|| self.default_provider());
        let model = tool_config
            .and_then(|cfg| cfg.model.as_deref())
            .or(params.model)
            .unwrap_or_else(|| self.default_model());
        let max_output_tokens = tool_config
            .and_then(|cfg| cfg.max_output_tokens)
            .or(params.max_output_tokens)
            .unwrap_or_else(|| self.default_max_output_tokens());

        if tool_config.is_some() {
            // The local names double as the tracing field names; keep them as is.
            tracing::debug!(
                provider,
                model,
                max_output_tokens,
                "Using tool_model_config in generate_response"
            );
        }

        let builder = AiRequest::builder(
            conversation,
            provider,
            model,
            max_output_tokens,
            params.context.clone(),
        );
        let followup_request = builder.build();

        Ok(self.generate(&followup_request).await?.content)
    }

    /// Estimates the cost of `response` from its token counts.
    ///
    /// Missing token counts are treated as zero. Pricing comes from the
    /// provider registered under `response.provider`; when no such provider
    /// exists, `ModelPricing::new(0.001, 0.001)` is used instead. Each token
    /// count is divided by 1000 and multiplied by its per-1k rate, and the sum
    /// is scaled by 1_000_000 and rounded to the nearest integer (the caller in
    /// this file stores the result as `cost_microdollars`).
    pub(super) fn estimate_cost(&self, response: &AiResponse) -> i64 {
        let prompt_tokens = response.input_tokens.map_or(0.0, f64::from);
        let completion_tokens = response.output_tokens.map_or(0.0, f64::from);

        // Fallback pricing for a provider that is not registered.
        let fallback = ModelPricing::new(0.001, 0.001);
        let pricing = match self.providers.get(&response.provider) {
            Some(provider) => provider.get_pricing(&response.model),
            None => fallback,
        };

        let input_rate = f64::from(pricing.input_cost_per_1k);
        let output_rate = f64::from(pricing.output_cost_per_1k);

        let total = (prompt_tokens / 1000.0) * input_rate + (completion_tokens / 1000.0) * output_rate;

        (total * 1_000_000.0).round() as i64
    }
}