praxis_types/
state.rs

1use crate::config::{LLMConfig, ContextPolicy};
2use praxis_llm::{Message, ToolCall};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
/// Mutable state threaded through a single graph execution run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphState {
    /// Identifier of the conversation this run belongs to.
    pub conversation_id: String,
    /// Unique identifier for this particular execution run.
    pub run_id: String,
    /// Ordered message history; new messages are appended at the end.
    pub messages: Vec<Message>,
    /// LLM settings used for model calls during this run.
    pub llm_config: LLMConfig,
    /// Arbitrary JSON values keyed by name; starts empty in both constructors.
    pub variables: HashMap<String, serde_json::Value>,
}
14
15impl GraphState {
16    pub fn new(
17        conversation_id: String,
18        run_id: String,
19        messages: Vec<Message>,
20        llm_config: LLMConfig,
21    ) -> Self {
22        Self {
23            conversation_id,
24            run_id,
25            messages,
26            llm_config,
27            variables: HashMap::new(),
28        }
29    }
30
31    pub fn from_input(input: GraphInput) -> Self {
32        let mut messages = Vec::new();
33        
34        // TODO: In a real implementation, we'd fetch history from DB here
35        // For now, just use the last message
36        messages.push(input.last_message);
37
38        Self {
39            conversation_id: input.conversation_id,
40            run_id: uuid::Uuid::new_v4().to_string(),
41            messages,
42            llm_config: input.llm_config,
43            variables: HashMap::new(),
44        }
45    }
46
47    pub fn last_message(&self) -> Option<&Message> {
48        self.messages.last()
49    }
50
51    pub fn add_message(&mut self, message: Message) {
52        self.messages.push(message);
53    }
54
55    pub fn has_pending_tool_calls(&self) -> bool {
56        if let Some(last_msg) = self.last_message() {
57            match last_msg {
58                Message::AI { tool_calls, .. } => tool_calls.is_some(),
59                _ => false,
60            }
61        } else {
62            false
63        }
64    }
65
66    pub fn get_pending_tool_calls(&self) -> Vec<ToolCall> {
67        if let Some(last_msg) = self.last_message() {
68            match last_msg {
69                Message::AI { tool_calls: Some(calls), .. } => calls.clone(),
70                _ => Vec::new(),
71            }
72        } else {
73            Vec::new()
74        }
75    }
76
77    pub fn add_tool_result(&mut self, tool_call_id: String, result: String) {
78        self.messages.push(Message::Tool {
79            tool_call_id,
80            content: praxis_llm::Content::text(result),
81        });
82    }
83}
84
/// Caller-supplied payload used to start a graph run
/// (see `GraphState::from_input`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphInput {
    /// Identifier of the conversation the run belongs to.
    pub conversation_id: String,
    /// The newest message; seeds the run's message history.
    pub last_message: Message,
    /// LLM settings to use for the run.
    pub llm_config: LLMConfig,
    /// How conversation context should be assembled; defaults via
    /// `ContextPolicy::default()` in `GraphInput::new`.
    pub context_policy: ContextPolicy,
}
92
93impl GraphInput {
94    pub fn new(
95        conversation_id: impl Into<String>,
96        last_message: Message,
97        llm_config: LLMConfig,
98    ) -> Self {
99        Self {
100            conversation_id: conversation_id.into(),
101            last_message,
102            llm_config,
103            context_policy: ContextPolicy::default(),
104        }
105    }
106
107    pub fn with_context_policy(mut self, policy: ContextPolicy) -> Self {
108        self.context_policy = policy;
109        self
110    }
111}
112