1use crate::config::{LLMConfig, ContextPolicy};
2use praxis_llm::{Message, ToolCall};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct GraphState {
8 pub conversation_id: String,
9 pub run_id: String,
10 pub messages: Vec<Message>,
11 pub llm_config: LLMConfig,
12 pub variables: HashMap<String, serde_json::Value>,
13}
14
15impl GraphState {
16 pub fn new(
17 conversation_id: String,
18 run_id: String,
19 messages: Vec<Message>,
20 llm_config: LLMConfig,
21 ) -> Self {
22 Self {
23 conversation_id,
24 run_id,
25 messages,
26 llm_config,
27 variables: HashMap::new(),
28 }
29 }
30
31 pub fn from_input(input: GraphInput) -> Self {
32 let mut messages = Vec::new();
33
34 messages.push(input.last_message);
37
38 Self {
39 conversation_id: input.conversation_id,
40 run_id: uuid::Uuid::new_v4().to_string(),
41 messages,
42 llm_config: input.llm_config,
43 variables: HashMap::new(),
44 }
45 }
46
47 pub fn last_message(&self) -> Option<&Message> {
48 self.messages.last()
49 }
50
51 pub fn add_message(&mut self, message: Message) {
52 self.messages.push(message);
53 }
54
55 pub fn has_pending_tool_calls(&self) -> bool {
56 if let Some(last_msg) = self.last_message() {
57 match last_msg {
58 Message::AI { tool_calls, .. } => tool_calls.is_some(),
59 _ => false,
60 }
61 } else {
62 false
63 }
64 }
65
66 pub fn get_pending_tool_calls(&self) -> Vec<ToolCall> {
67 if let Some(last_msg) = self.last_message() {
68 match last_msg {
69 Message::AI { tool_calls: Some(calls), .. } => calls.clone(),
70 _ => Vec::new(),
71 }
72 } else {
73 Vec::new()
74 }
75 }
76
77 pub fn add_tool_result(&mut self, tool_call_id: String, result: String) {
78 self.messages.push(Message::Tool {
79 tool_call_id,
80 content: praxis_llm::Content::text(result),
81 });
82 }
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct GraphInput {
87 pub conversation_id: String,
88 pub last_message: Message,
89 pub llm_config: LLMConfig,
90 pub context_policy: ContextPolicy,
91}
92
93impl GraphInput {
94 pub fn new(
95 conversation_id: impl Into<String>,
96 last_message: Message,
97 llm_config: LLMConfig,
98 ) -> Self {
99 Self {
100 conversation_id: conversation_id.into(),
101 last_message,
102 llm_config,
103 context_policy: ContextPolicy::default(),
104 }
105 }
106
107 pub fn with_context_policy(mut self, policy: ContextPolicy) -> Self {
108 self.context_policy = policy;
109 self
110 }
111}
112