praxis_graph/types/
state.rs1use crate::types::config::{LLMConfig, ContextPolicy};
2use crate::types::GraphOutput;
3use praxis_llm::{Message, ToolCall};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7#[derive(Debug, Clone)]
8pub struct GraphState {
9 pub conversation_id: String,
10 pub run_id: String,
11 pub messages: Vec<Message>,
12 pub llm_config: LLMConfig,
13 pub variables: HashMap<String, serde_json::Value>,
14 #[allow(dead_code)]
15 pub last_outputs: Option<Vec<GraphOutput>>,
16}
17
18impl GraphState {
19 pub fn new(
20 conversation_id: String,
21 run_id: String,
22 messages: Vec<Message>,
23 llm_config: LLMConfig,
24 ) -> Self {
25 Self {
26 conversation_id,
27 run_id,
28 messages,
29 llm_config,
30 variables: HashMap::new(),
31 last_outputs: None,
32 }
33 }
34
35 pub fn from_input(input: GraphInput) -> Self {
36 Self {
37 conversation_id: input.conversation_id,
38 run_id: uuid::Uuid::new_v4().to_string(),
39 messages: input.messages,
40 llm_config: input.llm_config,
41 variables: HashMap::new(),
42 last_outputs: None,
43 }
44 }
45
46 pub fn last_message(&self) -> Option<&Message> {
47 self.messages.last()
48 }
49
50 pub fn add_message(&mut self, message: Message) {
51 self.messages.push(message);
52 }
53
54 pub fn has_pending_tool_calls(&self) -> bool {
55 if let Some(last_msg) = self.last_message() {
56 match last_msg {
57 Message::AI { tool_calls, .. } => tool_calls.is_some(),
58 _ => false,
59 }
60 } else {
61 false
62 }
63 }
64
65 pub fn get_pending_tool_calls(&self) -> Vec<ToolCall> {
66 if let Some(last_msg) = self.last_message() {
67 match last_msg {
68 Message::AI { tool_calls: Some(calls), .. } => calls.clone(),
69 _ => Vec::new(),
70 }
71 } else {
72 Vec::new()
73 }
74 }
75
76 pub fn add_tool_result(&mut self, tool_call_id: String, result: String) {
77 self.messages.push(Message::Tool {
78 tool_call_id,
79 content: praxis_llm::Content::text(result),
80 });
81 }
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct GraphInput {
86 pub conversation_id: String,
87 pub messages: Vec<Message>,
88 pub llm_config: LLMConfig,
89 pub context_policy: ContextPolicy,
90}
91
92impl GraphInput {
93 pub fn new(
94 conversation_id: impl Into<String>,
95 messages: Vec<Message>,
96 llm_config: LLMConfig,
97 ) -> Self {
98 Self {
99 conversation_id: conversation_id.into(),
100 messages,
101 llm_config,
102 context_policy: ContextPolicy::default(),
103 }
104 }
105
106 pub fn with_context_policy(mut self, policy: ContextPolicy) -> Self {
107 self.context_policy = policy;
108 self
109 }
110}
111