systemprompt-models 0.1.18

Shared data models and types for systemprompt.io OS
Documentation
use regex::Regex;
use serde::{Deserialize, Serialize};
use serde_json::Value;

/// Outcome of a planning step: either a response given directly as text or a
/// list of tool calls to run.
///
/// Serialized by serde as an internally tagged enum: the variant is recorded in
/// a `type` field using its snake_case name (`direct_response` / `tool_calls`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum PlanningResult {
    /// A response carried directly as text; this variant holds no tool calls.
    DirectResponse {
        /// The response text.
        content: String,
    },
    /// One or more planned tool invocations.
    ToolCalls {
        /// Free-form text accompanying the plan (presumably the planner's
        /// rationale for the calls — confirm against the producer).
        reasoning: String,
        /// The planned calls, kept in the order they were supplied.
        calls: Vec<PlannedToolCall>,
    },
}

impl PlanningResult {
    /// Creates a [`PlanningResult::DirectResponse`] from anything convertible
    /// into a `String`.
    pub fn direct_response(content: impl Into<String>) -> Self {
        let content = content.into();
        Self::DirectResponse { content }
    }

    /// Creates a [`PlanningResult::ToolCalls`] from the accompanying reasoning
    /// text and the list of planned calls.
    pub fn tool_calls(reasoning: impl Into<String>, calls: Vec<PlannedToolCall>) -> Self {
        let reasoning = reasoning.into();
        Self::ToolCalls { reasoning, calls }
    }

    /// Returns `true` when this is the `DirectResponse` variant.
    pub const fn is_direct(&self) -> bool {
        match self {
            Self::DirectResponse { .. } => true,
            Self::ToolCalls { .. } => false,
        }
    }

    /// Returns `true` when this is the `ToolCalls` variant.
    pub const fn is_tool_calls(&self) -> bool {
        match self {
            Self::ToolCalls { .. } => true,
            Self::DirectResponse { .. } => false,
        }
    }

    /// Number of planned tool calls; a direct response always counts as zero.
    pub fn tool_count(&self) -> usize {
        if let Self::ToolCalls { calls, .. } = self {
            calls.len()
        } else {
            0
        }
    }
}

/// A single tool invocation proposed by a plan: which tool to call and the
/// JSON arguments to call it with.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlannedToolCall {
    /// Name identifying the tool to invoke.
    pub tool_name: String,
    /// Arguments for the call as an arbitrary JSON value; no schema is
    /// enforced by this type.
    pub arguments: Value,
}

impl PlannedToolCall {
    /// Creates a planned call for `tool_name` with the given JSON `arguments`.
    ///
    /// `tool_name` accepts anything convertible into a `String`.
    pub fn new(tool_name: impl Into<String>, arguments: Value) -> Self {
        let tool_name = tool_name.into();
        Self {
            tool_name,
            arguments,
        }
    }
}

/// Recorded outcome of one tool invocation.
///
/// The `success` and `failure` constructors keep the fields consistent:
/// a success has `error == None`, a failure has `output == Value::Null` and
/// `error == Some(..)`. Because every field is public, values built by hand or
/// deserialized are not guaranteed to follow that convention.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallResult {
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// JSON arguments the tool was invoked with.
    pub arguments: Value,
    /// `true` if the call succeeded, `false` if it failed.
    pub success: bool,
    /// JSON output of the call (`Value::Null` when built via `failure`).
    pub output: Value,
    /// Error message for a failed call (`None` when built via `success`).
    pub error: Option<String>,
    /// Duration of the call in milliseconds, as implied by the field name.
    pub duration_ms: u64,
}

impl ToolCallResult {
    /// Builds the record for a call that succeeded.
    ///
    /// Sets `success` to `true` and leaves `error` empty; all other fields are
    /// stored exactly as given.
    pub const fn success(
        tool_name: String,
        arguments: Value,
        output: Value,
        duration_ms: u64,
    ) -> Self {
        Self {
            // Outcome fields first, then the call description.
            success: true,
            error: None,
            output,
            tool_name,
            arguments,
            duration_ms,
        }
    }

    /// Builds the record for a call that failed.
    ///
    /// Sets `success` to `false`, stores the converted `error` message and
    /// uses `Value::Null` as the output.
    pub fn failure(
        tool_name: String,
        arguments: Value,
        error: impl Into<String>,
        duration_ms: u64,
    ) -> Self {
        let error = Some(error.into());
        Self {
            success: false,
            error,
            output: Value::Null,
            tool_name,
            arguments,
            duration_ms,
        }
    }
}

/// Accumulated state of a sequence of tool calls: every recorded result plus
/// a halt flag that `add_result` raises on the first failure.
///
/// The derived `Default` yields an empty, non-halted state.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ExecutionState {
    /// All recorded results, in insertion order (failed ones included).
    pub results: Vec<ToolCallResult>,
    /// Set to `true` by `add_result` once a failed result has been recorded.
    pub halted: bool,
    /// Copy of the `error` field of the first failed result recorded through
    /// `add_result`; stays `None` if that result carried no error message.
    pub halt_reason: Option<String>,
}

impl ExecutionState {
    /// Creates an empty state: no results, not halted, no halt reason.
    pub fn new() -> Self {
        Self {
            results: Vec::new(),
            halted: false,
            halt_reason: None,
        }
    }

    /// Records `result`.
    ///
    /// The first failed result flips `halted` to `true` and copies its error
    /// into `halt_reason`; later failures are still stored but leave the halt
    /// information untouched.
    pub fn add_result(&mut self, result: ToolCallResult) {
        let is_first_failure = !(result.success || self.halted);
        if is_first_failure {
            self.halt_reason.clone_from(&result.error);
            self.halted = true;
        }
        self.results.push(result);
    }

    /// Borrows every recorded result whose `success` flag is set, in order.
    pub fn successful_results(&self) -> Vec<&ToolCallResult> {
        let mut succeeded = Vec::new();
        for entry in &self.results {
            if entry.success {
                succeeded.push(entry);
            }
        }
        succeeded
    }

    /// Borrows every recorded result whose `success` flag is cleared, in order.
    pub fn failed_results(&self) -> Vec<&ToolCallResult> {
        let mut failed = Vec::new();
        for entry in &self.results {
            if !entry.success {
                failed.push(entry);
            }
        }
        failed
    }

    /// Sum of `duration_ms` over all recorded results (0 when empty).
    pub fn total_duration_ms(&self) -> u64 {
        let durations = self.results.iter().map(|entry| entry.duration_ms);
        durations.sum()
    }
}

/// Parsed form of a template reference written as `$<index>.output.<path>`,
/// e.g. `$0.output.user.id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemplateRef {
    /// The number that follows `$` in the template (presumably the position of
    /// an earlier tool call whose output is referenced — confirm with callers).
    pub tool_index: usize,
    /// The dot-separated segments that follow `.output.`, in order.
    pub field_path: Vec<String>,
}

impl TemplateRef {
    pub fn parse(template: &str) -> Option<Self> {
        let re = Regex::new(r"^\$(\d+)\.output\.(.+)$").ok()?;
        let caps = re.captures(template)?;

        let tool_index = caps.get(1)?.as_str().parse().ok()?;
        let path = caps.get(2)?.as_str();
        let field_path = path.split('.').map(String::from).collect();

        Some(Self {
            tool_index,
            field_path,
        })
    }

    pub fn format(&self) -> String {
        format!("${}.output.{}", self.tool_index, self.field_path.join("."))
    }
}