praxis_graph/types/state.rs

1use crate::types::config::{LLMConfig, ContextPolicy};
2use praxis_llm::{Message, ToolCall};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct GraphState {
8    pub conversation_id: String,
9    pub run_id: String,
10    pub messages: Vec<Message>,
11    pub llm_config: LLMConfig,
12    pub variables: HashMap<String, serde_json::Value>,
13}
14
15impl GraphState {
16    pub fn new(
17        conversation_id: String,
18        run_id: String,
19        messages: Vec<Message>,
20        llm_config: LLMConfig,
21    ) -> Self {
22        Self {
23            conversation_id,
24            run_id,
25            messages,
26            llm_config,
27            variables: HashMap::new(),
28        }
29    }
30
31    pub fn from_input(input: GraphInput) -> Self {
32        Self {
33            conversation_id: input.conversation_id,
34            run_id: uuid::Uuid::new_v4().to_string(),
35            messages: input.messages,
36            llm_config: input.llm_config,
37            variables: HashMap::new(),
38        }
39    }
40
41    pub fn last_message(&self) -> Option<&Message> {
42        self.messages.last()
43    }
44
45    pub fn add_message(&mut self, message: Message) {
46        self.messages.push(message);
47    }
48
49    pub fn has_pending_tool_calls(&self) -> bool {
50        if let Some(last_msg) = self.last_message() {
51            match last_msg {
52                Message::AI { tool_calls, .. } => tool_calls.is_some(),
53                _ => false,
54            }
55        } else {
56            false
57        }
58    }
59
60    pub fn get_pending_tool_calls(&self) -> Vec<ToolCall> {
61        if let Some(last_msg) = self.last_message() {
62            match last_msg {
63                Message::AI { tool_calls: Some(calls), .. } => calls.clone(),
64                _ => Vec::new(),
65            }
66        } else {
67            Vec::new()
68        }
69    }
70
71    pub fn add_tool_result(&mut self, tool_call_id: String, result: String) {
72        self.messages.push(Message::Tool {
73            tool_call_id,
74            content: praxis_llm::Content::text(result),
75        });
76    }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct GraphInput {
81    pub conversation_id: String,
82    pub messages: Vec<Message>,
83    pub llm_config: LLMConfig,
84    pub context_policy: ContextPolicy,
85}
86
87impl GraphInput {
88    pub fn new(
89        conversation_id: impl Into<String>,
90        messages: Vec<Message>,
91        llm_config: LLMConfig,
92    ) -> Self {
93        Self {
94            conversation_id: conversation_id.into(),
95            messages,
96            llm_config,
97            context_policy: ContextPolicy::default(),
98        }
99    }
100
101    pub fn with_context_policy(mut self, policy: ContextPolicy) -> Self {
102        self.context_policy = policy;
103        self
104    }
105}
106