// praxis_graph/types/state.rs

1use crate::types::config::{LLMConfig, ContextPolicy};
2use crate::types::GraphOutput;
3use praxis_llm::{Message, ToolCall};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7#[derive(Debug, Clone)]
8pub struct GraphState {
9    pub conversation_id: String,
10    pub run_id: String,
11    pub messages: Vec<Message>,
12    pub llm_config: LLMConfig,
13    pub variables: HashMap<String, serde_json::Value>,
14    #[allow(dead_code)]
15    pub last_outputs: Option<Vec<GraphOutput>>,
16}
17
18impl GraphState {
19    pub fn new(
20        conversation_id: String,
21        run_id: String,
22        messages: Vec<Message>,
23        llm_config: LLMConfig,
24    ) -> Self {
25        Self {
26            conversation_id,
27            run_id,
28            messages,
29            llm_config,
30            variables: HashMap::new(),
31            last_outputs: None,
32        }
33    }
34
35    pub fn from_input(input: GraphInput) -> Self {
36        Self {
37            conversation_id: input.conversation_id,
38            run_id: uuid::Uuid::new_v4().to_string(),
39            messages: input.messages,
40            llm_config: input.llm_config,
41            variables: HashMap::new(),
42            last_outputs: None,
43        }
44    }
45
46    pub fn last_message(&self) -> Option<&Message> {
47        self.messages.last()
48    }
49
50    pub fn add_message(&mut self, message: Message) {
51        self.messages.push(message);
52    }
53
54    pub fn has_pending_tool_calls(&self) -> bool {
55        if let Some(last_msg) = self.last_message() {
56            match last_msg {
57                Message::AI { tool_calls, .. } => tool_calls.is_some(),
58                _ => false,
59            }
60        } else {
61            false
62        }
63    }
64
65    pub fn get_pending_tool_calls(&self) -> Vec<ToolCall> {
66        if let Some(last_msg) = self.last_message() {
67            match last_msg {
68                Message::AI { tool_calls: Some(calls), .. } => calls.clone(),
69                _ => Vec::new(),
70            }
71        } else {
72            Vec::new()
73        }
74    }
75
76    pub fn add_tool_result(&mut self, tool_call_id: String, result: String) {
77        self.messages.push(Message::Tool {
78            tool_call_id,
79            content: praxis_llm::Content::text(result),
80        });
81    }
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct GraphInput {
86    pub conversation_id: String,
87    pub messages: Vec<Message>,
88    pub llm_config: LLMConfig,
89    pub context_policy: ContextPolicy,
90}
91
92impl GraphInput {
93    pub fn new(
94        conversation_id: impl Into<String>,
95        messages: Vec<Message>,
96        llm_config: LLMConfig,
97    ) -> Self {
98        Self {
99            conversation_id: conversation_id.into(),
100            messages,
101            llm_config,
102            context_policy: ContextPolicy::default(),
103        }
104    }
105
106    pub fn with_context_policy(mut self, policy: ContextPolicy) -> Self {
107        self.context_policy = policy;
108        self
109    }
110}
111