1use serde::{Deserialize, Serialize};
13use std::collections::HashSet;
14use std::time::Instant;
15use tracing::{debug, info, warn};
16
17use super::{serialize_for_log, ModeCore};
18use crate::config::Config;
19use crate::error::{AppResult, ToolError};
20use crate::langbase::{LangbaseClient, Message, PipeRequest};
21use crate::prompts::{
22 GOT_AGGREGATE_PROMPT, GOT_GENERATE_PROMPT, GOT_REFINE_PROMPT, GOT_SCORE_PROMPT,
23};
24use crate::storage::{
25 EdgeType, GraphEdge, GraphNode, Invocation, NodeType, SqliteStorage, Storage,
26};
27
// Unit tests for this module are compiled only under `cfg(test)` and live in the
// sibling file `got_tests.rs` (selected via the `#[path]` attribute).
#[cfg(test)]
#[path = "got_tests.rs"]
mod got_tests;
31
/// Tunables for Graph-of-Thoughts reasoning.
///
/// Every field has a serde default, so a partial (or empty) JSON object deserializes into a
/// complete config.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotConfig {
    /// Node budget for a graph; defaults to `default_max_nodes()`.
    /// NOTE(review): no `GotMode` method in this file reads this value — confirm whether the
    /// limit is enforced anywhere.
    #[serde(default = "default_max_nodes")]
    pub max_nodes: usize,
    /// Depth at which `GotMode::generate` refuses to expand a node any further.
    #[serde(default = "default_max_depth")]
    pub max_depth: usize,
    /// Default number of continuations per generate step.
    /// NOTE(review): `GotMode::generate` uses `GotGenerateParams::k` and does not read this
    /// field in this file — confirm intended wiring.
    #[serde(default = "default_k")]
    pub default_k: usize,
    /// Score threshold used by `GotMode::prune` when the caller supplies none; leaf nodes
    /// scored strictly below it are pruned.
    #[serde(default = "default_prune_threshold")]
    pub prune_threshold: f64,
}
48
/// Built-in value for [`GotConfig::max_nodes`].
const DEFAULT_MAX_NODES: usize = 100;
/// Built-in value for [`GotConfig::max_depth`].
const DEFAULT_MAX_DEPTH: usize = 10;
/// Built-in number of continuations per generate step.
const DEFAULT_K: usize = 3;
/// Built-in value for [`GotConfig::prune_threshold`].
const DEFAULT_PRUNE_THRESHOLD: f64 = 0.3;

/// Serde/config fallback for `max_nodes`.
fn default_max_nodes() -> usize {
    DEFAULT_MAX_NODES
}

/// Serde/config fallback for `max_depth`.
fn default_max_depth() -> usize {
    DEFAULT_MAX_DEPTH
}

/// Serde/config fallback for `default_k` and for `GotGenerateParams::k`.
fn default_k() -> usize {
    DEFAULT_K
}

/// Serde/config fallback for `prune_threshold`.
fn default_prune_threshold() -> f64 {
    DEFAULT_PRUNE_THRESHOLD
}
64
65impl Default for GotConfig {
66 fn default() -> Self {
67 Self {
68 max_nodes: default_max_nodes(),
69 max_depth: default_max_depth(),
70 default_k: default_k(),
71 prune_threshold: default_prune_threshold(),
72 }
73 }
74}
75
/// Input for [`GotMode::initialize`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotInitParams {
    /// Text of the root thought; rejected when blank (whitespace only).
    pub content: String,
    /// Optional problem statement.
    /// NOTE(review): `initialize` does not read this field in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub problem: Option<String>,
    /// Session to attach to; handed to `get_or_create_session`, which creates one if needed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Config echoed back in the result; when `None`, the mode's own config is echoed.
    /// NOTE(review): this override is not persisted, and later steps read the mode's config.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub config: Option<GotConfig>,
}
95
/// Output of [`GotMode::initialize`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotInitResult {
    /// ID of the (existing or newly created) session.
    pub session_id: String,
    /// ID of the freshly stored root node.
    pub root_node_id: String,
    /// Echo of the root node's content.
    pub content: String,
    /// Effective config: the caller's override, or the mode's config.
    pub config: GotConfig,
}
108
/// Input for [`GotMode::generate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotGenerateParams {
    /// Session whose graph is expanded; new nodes and edges are created in it.
    pub session_id: String,
    /// Node to expand; when `None`, the first active node of the session is used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub node_id: Option<String>,
    /// Maximum number of continuations to keep; missing in JSON → `default_k()`.
    /// NOTE(review): the serde default is the built-in value, not `GotConfig::default_k`.
    #[serde(default = "default_k")]
    pub k: usize,
    /// Optional problem statement appended to the prompt as "Problem context".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub problem: Option<String>,
}
128
/// One child thought created by [`GotMode::generate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratedContinuation {
    /// ID of the new graph node.
    pub node_id: String,
    /// Thought text produced by the pipe.
    pub content: String,
    /// Pipe-reported confidence; also stored as the node's score and the edge weight.
    pub confidence: f64,
    /// Pipe-reported novelty (0.0 when the pipe omitted it).
    pub novelty: f64,
    /// Pipe-reported rationale (empty when the pipe omitted it).
    pub rationale: String,
}
143
/// Output of [`GotMode::generate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotGenerateResult {
    /// Session the nodes were created in.
    pub session_id: String,
    /// ID of the node that was expanded.
    pub source_node_id: String,
    /// The child thoughts that were stored, in pipe order.
    pub continuations: Vec<GeneratedContinuation>,
    /// Continuation count as reported by `generate`.
    pub count: usize,
}
156
/// Expected JSON shape of the GoT pipe's completion for a generate request.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GenerateResponse {
    /// Proposed continuations; `generate` keeps at most `k` of them.
    continuations: Vec<ContinuationItem>,
    /// Free-form metadata from the pipe; not otherwise inspected by `generate`.
    #[serde(default)]
    metadata: Option<serde_json::Value>,
}
164
/// One continuation entry inside a [`GenerateResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ContinuationItem {
    /// Thought text (required).
    thought: String,
    /// Confidence; `default_confidence()` (0.7) when missing.
    #[serde(default = "default_confidence")]
    confidence: f64,
    /// Novelty; 0.0 when missing.
    #[serde(default)]
    novelty: f64,
    /// Rationale; empty when missing.
    #[serde(default)]
    rationale: String,
}
175
/// Confidence assumed when the pipe omits one (continuations, aggregations, refinements).
fn default_confidence() -> f64 {
    const FALLBACK_CONFIDENCE: f64 = 0.7;
    FALLBACK_CONFIDENCE
}
179
180impl GenerateResponse {
181 fn from_completion(completion: &str) -> Result<Self, ToolError> {
183 serde_json::from_str::<GenerateResponse>(completion).map_err(|e| {
184 let preview: String = completion.chars().take(200).collect();
185 ToolError::ParseFailed {
186 mode: "got.generate".to_string(),
187 message: format!("JSON parse error: {} | Response preview: {}", e, preview),
188 }
189 })
190 }
191}
192
/// Input for [`GotMode::score`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotScoreParams {
    /// Session the call is logged against.
    pub session_id: String,
    /// Node to score.
    pub node_id: String,
    /// Optional problem statement appended to the prompt as "Problem context".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub problem: Option<String>,
}
208
/// Per-criterion scores returned by [`GotMode::score`], copied from the pipe response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreBreakdown {
    /// Relevance score.
    pub relevance: f64,
    /// Validity score.
    pub validity: f64,
    /// Depth score.
    pub depth: f64,
    /// Novelty score.
    pub novelty: f64,
}
221
/// Output of [`GotMode::score`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotScoreResult {
    /// Session the call was logged against.
    pub session_id: String,
    /// ID of the scored node.
    pub node_id: String,
    /// Overall score; this value is also written to the node's `score`.
    pub overall_score: f64,
    /// Per-criterion scores.
    pub breakdown: ScoreBreakdown,
    /// Pipe's opinion on whether the node could be a final conclusion (not stored).
    pub is_terminal_candidate: bool,
    /// Pipe's explanation of the score (not stored).
    pub rationale: String,
}
238
/// Expected JSON shape of the GoT pipe's completion for a score request.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ScoreResponse {
    /// Overall score (required).
    overall_score: f64,
    /// Per-criterion breakdown (required object; individual entries default).
    breakdown: ScoreBreakdownResponse,
    /// Defaults to `false` when missing.
    #[serde(default)]
    is_terminal_candidate: bool,
    /// Defaults to empty when missing.
    #[serde(default)]
    rationale: String,
    /// Free-form metadata; only surfaces in the logged invocation output.
    #[serde(default)]
    metadata: Option<serde_json::Value>,
}
251
/// Breakdown object inside a [`ScoreResponse`]; each missing entry defaults to
/// `default_score()` (0.5).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ScoreBreakdownResponse {
    #[serde(default = "default_score")]
    relevance: f64,
    #[serde(default = "default_score")]
    validity: f64,
    #[serde(default = "default_score")]
    depth: f64,
    #[serde(default = "default_score")]
    novelty: f64,
}
263
/// Neutral score used for any breakdown criterion the pipe leaves out.
fn default_score() -> f64 {
    const NEUTRAL_SCORE: f64 = 0.5;
    NEUTRAL_SCORE
}
267
268impl ScoreResponse {
269 fn from_completion(completion: &str) -> Result<Self, ToolError> {
271 serde_json::from_str::<ScoreResponse>(completion).map_err(|e| {
272 let preview: String = completion.chars().take(200).collect();
273 ToolError::ParseFailed {
274 mode: "got.score".to_string(),
275 message: format!("JSON parse error: {} | Response preview: {}", e, preview),
276 }
277 })
278 }
279}
280
/// Input for [`GotMode::aggregate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotAggregateParams {
    /// Session the aggregated node is created in.
    pub session_id: String,
    /// Nodes to merge; at least two are required.
    pub node_ids: Vec<String>,
    /// Optional problem statement appended to the prompt as "Problem context".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub problem: Option<String>,
}
296
/// Output of [`GotMode::aggregate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotAggregateResult {
    /// Session the aggregated node lives in.
    pub session_id: String,
    /// ID of the new `Aggregation` node.
    pub aggregated_node_id: String,
    /// Synthesized thought text.
    pub content: String,
    /// Pipe-reported confidence; also stored as the node's score.
    pub confidence: f64,
    /// IDs of the nodes that were merged.
    pub source_nodes: Vec<String>,
    /// Pipe's description of how it combined the inputs.
    pub synthesis_approach: String,
    /// Conflicts the pipe reports having resolved.
    pub conflicts_resolved: Vec<String>,
}
315
/// Expected JSON shape of the GoT pipe's completion for an aggregate request.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AggregateResponse {
    /// Synthesized thought (required).
    aggregated_thought: String,
    /// Confidence; `default_confidence()` (0.7) when missing.
    #[serde(default = "default_confidence")]
    confidence: f64,
    /// Inputs the pipe says it used; only surfaces in the logged invocation output.
    #[serde(default)]
    sources_used: Vec<String>,
    #[serde(default)]
    synthesis_approach: String,
    #[serde(default)]
    conflicts_resolved: Vec<String>,
    /// Free-form metadata; only surfaces in the logged invocation output.
    #[serde(default)]
    metadata: Option<serde_json::Value>,
}
331
332impl AggregateResponse {
333 fn from_completion(completion: &str) -> Result<Self, ToolError> {
335 serde_json::from_str::<AggregateResponse>(completion).map_err(|e| {
336 let preview: String = completion.chars().take(200).collect();
337 ToolError::ParseFailed {
338 mode: "got.aggregate".to_string(),
339 message: format!("JSON parse error: {} | Response preview: {}", e, preview),
340 }
341 })
342 }
343}
344
/// Input for [`GotMode::refine`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotRefineParams {
    /// Session the refined node is created in.
    pub session_id: String,
    /// Node to refine.
    pub node_id: String,
    /// Optional problem statement appended to the prompt as "Problem context".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub problem: Option<String>,
}
360
/// Output of [`GotMode::refine`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotRefineResult {
    /// Session the refined node lives in.
    pub session_id: String,
    /// ID of the node that was refined (now inactive).
    pub original_node_id: String,
    /// ID of the new `Refinement` node.
    pub refined_node_id: String,
    /// Refined thought text.
    pub content: String,
    /// Pipe-reported confidence; also stored as the new node's score.
    pub confidence: f64,
    /// Improvements the pipe reports having made.
    pub improvements_made: Vec<String>,
    /// Pipe-reported quality change (0.0 when omitted).
    pub quality_delta: f64,
}
379
/// Expected JSON shape of the GoT pipe's completion for a refine request.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RefineResponse {
    /// Refined thought (required).
    refined_thought: String,
    /// Confidence; `default_confidence()` (0.7) when missing.
    #[serde(default = "default_confidence")]
    confidence: f64,
    #[serde(default)]
    improvements_made: Vec<String>,
    /// Only surfaces in the logged invocation output.
    #[serde(default)]
    aspects_unchanged: Vec<String>,
    #[serde(default)]
    quality_delta: f64,
    /// Free-form metadata; only surfaces in the logged invocation output.
    #[serde(default)]
    metadata: Option<serde_json::Value>,
}
395
396impl RefineResponse {
397 fn from_completion(completion: &str) -> Result<Self, ToolError> {
399 serde_json::from_str::<RefineResponse>(completion).map_err(|e| {
400 let preview: String = completion.chars().take(200).collect();
401 ToolError::ParseFailed {
402 mode: "got.refine".to_string(),
403 message: format!("JSON parse error: {} | Response preview: {}", e, preview),
404 }
405 })
406 }
407}
408
/// Input for [`GotMode::prune`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotPruneParams {
    /// Session whose graph is pruned.
    pub session_id: String,
    /// Score threshold; when `None`, the mode's `prune_threshold` is used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub threshold: Option<f64>,
}
422
/// Output of [`GotMode::prune`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotPruneResult {
    /// Session that was pruned.
    pub session_id: String,
    /// Number of nodes deleted.
    pub pruned_count: usize,
    /// Nodes left in the session after pruning.
    pub remaining_count: usize,
    /// Threshold actually applied.
    pub threshold_used: f64,
    /// IDs of the deleted nodes.
    pub pruned_node_ids: Vec<String>,
}
437
/// Input for [`GotMode::finalize`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotFinalizeParams {
    /// Session being finalized.
    pub session_id: String,
    /// Nodes to mark terminal. When empty (the serde default), `finalize` picks up to three
    /// of the session's highest-scored active nodes.
    #[serde(default)]
    pub terminal_node_ids: Vec<String>,
}
451
/// A node that [`GotMode::finalize`] marked terminal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TerminalConclusion {
    /// Node ID.
    pub node_id: String,
    /// Node content.
    pub content: String,
    /// Node score, if it was ever scored.
    pub score: Option<f64>,
    /// Node depth in the graph.
    pub depth: i32,
}
464
/// Output of [`GotMode::finalize`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotFinalizeResult {
    /// Session that was finalized.
    pub session_id: String,
    /// Number of entries in `conclusions`.
    pub terminal_count: usize,
    /// The finalized nodes.
    pub conclusions: Vec<TerminalConclusion>,
}
475
/// Input for [`GotMode::get_state`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotGetStateParams {
    /// Session to summarize.
    pub session_id: String,
}
486
/// Snapshot of a session's graph returned by [`GotMode::get_state`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GotStateResult {
    /// Session summarized.
    pub session_id: String,
    /// All nodes in the session.
    pub total_nodes: usize,
    /// Nodes with `is_active` set.
    pub active_nodes: usize,
    /// Nodes with `is_terminal` set.
    pub terminal_nodes: usize,
    /// All edges in the session.
    pub total_edges: usize,
    /// Largest node depth (0 for an empty graph).
    pub max_depth: i32,
    /// IDs of nodes with `is_root` set.
    pub root_node_ids: Vec<String>,
    /// IDs of active nodes.
    pub active_node_ids: Vec<String>,
    /// IDs of terminal nodes.
    pub terminal_node_ids: Vec<String>,
}
509
/// Graph-of-Thoughts reasoning mode: builds, expands, scores, merges, refines and prunes a
/// per-session graph of thoughts stored via `SqliteStorage`, using a Langbase pipe.
#[derive(Clone)]
pub struct GotMode {
    /// Shared storage and Langbase client access.
    core: ModeCore,
    /// Name of the Langbase pipe used for every GoT call.
    got_pipe: String,
    /// Mode-wide tunables (`generate` reads `max_depth`; `prune` reads `prune_threshold`).
    config: GotConfig,
}
524
525impl GotMode {
526 pub fn new(storage: SqliteStorage, langbase: LangbaseClient, config: &Config) -> Self {
528 let got_config = config
529 .pipes
530 .got
531 .as_ref()
532 .map(|g| GotConfig {
533 max_nodes: g.max_nodes.unwrap_or_else(default_max_nodes),
534 max_depth: g.max_depth.unwrap_or_else(default_max_depth),
535 default_k: g.default_k.unwrap_or_else(default_k),
536 prune_threshold: g.prune_threshold.unwrap_or_else(default_prune_threshold),
537 })
538 .unwrap_or_default();
539
540 Self {
541 core: ModeCore::new(storage, langbase),
542 got_pipe: config
543 .pipes
544 .got
545 .as_ref()
546 .and_then(|g| g.pipe.clone())
547 .unwrap_or_else(|| "got-reasoning-v1".to_string()),
548 config: got_config,
549 }
550 }
551
552 pub async fn initialize(&self, params: GotInitParams) -> AppResult<GotInitResult> {
554 let start = Instant::now();
555
556 if params.content.trim().is_empty() {
558 return Err(ToolError::Validation {
559 field: "content".to_string(),
560 reason: "Content cannot be empty".to_string(),
561 }
562 .into());
563 }
564
565 let session = self
567 .core
568 .storage()
569 .get_or_create_session(¶ms.session_id, "got")
570 .await?;
571
572 let effective_config = params.config.unwrap_or_else(|| self.config.clone());
574
575 let root_node = GraphNode::new(&session.id, ¶ms.content)
577 .with_type(NodeType::Root)
578 .with_depth(0)
579 .as_root()
580 .as_active();
581
582 self.core.storage().create_graph_node(&root_node).await?;
583
584 let latency = start.elapsed().as_millis() as i64;
585 info!(
586 session_id = %session.id,
587 root_node_id = %root_node.id,
588 latency_ms = latency,
589 "GoT graph initialized"
590 );
591
592 Ok(GotInitResult {
593 session_id: session.id,
594 root_node_id: root_node.id,
595 content: params.content,
596 config: effective_config,
597 })
598 }
599
600 pub async fn generate(&self, params: GotGenerateParams) -> AppResult<GotGenerateResult> {
602 let start = Instant::now();
603
604 let source_node = match ¶ms.node_id {
606 Some(id) => self
607 .core
608 .storage()
609 .get_graph_node(id)
610 .await?
611 .ok_or_else(|| ToolError::Validation {
612 field: "node_id".to_string(),
613 reason: format!("Node not found: {}", id),
614 })?,
615 None => {
616 let active = self
617 .core
618 .storage()
619 .get_active_graph_nodes(¶ms.session_id)
620 .await?;
621 active
622 .into_iter()
623 .next()
624 .ok_or_else(|| ToolError::Validation {
625 field: "session_id".to_string(),
626 reason: "No active nodes in session".to_string(),
627 })?
628 }
629 };
630
631 debug!(
632 session_id = %params.session_id,
633 source_node_id = %source_node.id,
634 k = params.k,
635 "Generating GoT continuations"
636 );
637
638 if source_node.depth >= self.config.max_depth as i32 {
640 return Err(ToolError::Validation {
641 field: "depth".to_string(),
642 reason: format!("Maximum depth {} reached", self.config.max_depth),
643 }
644 .into());
645 }
646
647 let messages =
649 self.build_generate_messages(&source_node, params.k, params.problem.as_deref());
650
651 let mut invocation = Invocation::new(
653 "reasoning.got.generate",
654 serialize_for_log(¶ms, "reasoning.got.generate input"),
655 )
656 .with_session(¶ms.session_id)
657 .with_pipe(&self.got_pipe);
658
659 let request = PipeRequest::new(&self.got_pipe, messages);
661 let response = match self.core.langbase().call_pipe(request).await {
662 Ok(resp) => resp,
663 Err(e) => {
664 let latency = start.elapsed().as_millis() as i64;
665 invocation = invocation.failure(e.to_string(), latency);
666 if let Err(log_err) = self.core.storage().log_invocation(&invocation).await {
667 warn!(
668 error = %log_err,
669 tool = %invocation.tool_name,
670 "Failed to log invocation - audit trail incomplete"
671 );
672 }
673 return Err(e.into());
674 }
675 };
676
677 let gen_response = GenerateResponse::from_completion(&response.completion)?;
679
680 let mut continuations = Vec::new();
682 for item in gen_response.continuations.into_iter().take(params.k) {
683 let node = GraphNode::new(¶ms.session_id, &item.thought)
685 .with_type(NodeType::Thought)
686 .with_depth(source_node.depth + 1)
687 .with_score(item.confidence)
688 .as_active();
689
690 self.core.storage().create_graph_node(&node).await?;
691
692 let edge = GraphEdge::new(¶ms.session_id, &source_node.id, &node.id)
694 .with_type(EdgeType::Generates)
695 .with_weight(item.confidence);
696
697 self.core.storage().create_graph_edge(&edge).await?;
698
699 continuations.push(GeneratedContinuation {
700 node_id: node.id,
701 content: item.thought,
702 confidence: item.confidence,
703 novelty: item.novelty,
704 rationale: item.rationale,
705 });
706 }
707
708 let mut updated_source = source_node.clone();
710 updated_source.is_active = false;
711 self.core
712 .storage()
713 .update_graph_node(&updated_source)
714 .await?;
715
716 let latency = start.elapsed().as_millis() as i64;
717 invocation = invocation.success(
718 serialize_for_log(&continuations, "reasoning.got.generate output"),
719 latency,
720 );
721 if let Err(log_err) = self.core.storage().log_invocation(&invocation).await {
722 warn!(
723 error = %log_err,
724 tool = %invocation.tool_name,
725 "Failed to log invocation - audit trail incomplete"
726 );
727 }
728
729 info!(
730 session_id = %params.session_id,
731 source_node_id = %source_node.id,
732 generated_count = continuations.len(),
733 latency_ms = latency,
734 "GoT generate completed"
735 );
736
737 Ok(GotGenerateResult {
738 session_id: params.session_id,
739 source_node_id: source_node.id,
740 continuations,
741 count: params.k,
742 })
743 }
744
745 pub async fn score(&self, params: GotScoreParams) -> AppResult<GotScoreResult> {
747 let start = Instant::now();
748
749 let node = self
751 .core
752 .storage()
753 .get_graph_node(¶ms.node_id)
754 .await?
755 .ok_or_else(|| ToolError::Validation {
756 field: "node_id".to_string(),
757 reason: format!("Node not found: {}", params.node_id),
758 })?;
759
760 debug!(
761 session_id = %params.session_id,
762 node_id = %node.id,
763 "Scoring GoT node"
764 );
765
766 let messages = self.build_score_messages(&node, params.problem.as_deref());
768
769 let mut invocation = Invocation::new(
771 "reasoning.got.score",
772 serialize_for_log(¶ms, "reasoning.got.score input"),
773 )
774 .with_session(¶ms.session_id)
775 .with_pipe(&self.got_pipe);
776
777 let request = PipeRequest::new(&self.got_pipe, messages);
779 let response = match self.core.langbase().call_pipe(request).await {
780 Ok(resp) => resp,
781 Err(e) => {
782 let latency = start.elapsed().as_millis() as i64;
783 invocation = invocation.failure(e.to_string(), latency);
784 if let Err(log_err) = self.core.storage().log_invocation(&invocation).await {
785 warn!(
786 error = %log_err,
787 tool = %invocation.tool_name,
788 "Failed to log invocation - audit trail incomplete"
789 );
790 }
791 return Err(e.into());
792 }
793 };
794
795 let score_response = ScoreResponse::from_completion(&response.completion)?;
797
798 let mut updated_node = node.clone();
800 updated_node.score = Some(score_response.overall_score);
801 self.core.storage().update_graph_node(&updated_node).await?;
802
803 let latency = start.elapsed().as_millis() as i64;
804 invocation = invocation.success(
805 serialize_for_log(&score_response, "reasoning.got.score output"),
806 latency,
807 );
808 if let Err(log_err) = self.core.storage().log_invocation(&invocation).await {
809 warn!(
810 error = %log_err,
811 tool = %invocation.tool_name,
812 "Failed to log invocation - audit trail incomplete"
813 );
814 }
815
816 info!(
817 session_id = %params.session_id,
818 node_id = %node.id,
819 score = score_response.overall_score,
820 latency_ms = latency,
821 "GoT score completed"
822 );
823
824 Ok(GotScoreResult {
825 session_id: params.session_id,
826 node_id: node.id,
827 overall_score: score_response.overall_score,
828 breakdown: ScoreBreakdown {
829 relevance: score_response.breakdown.relevance,
830 validity: score_response.breakdown.validity,
831 depth: score_response.breakdown.depth,
832 novelty: score_response.breakdown.novelty,
833 },
834 is_terminal_candidate: score_response.is_terminal_candidate,
835 rationale: score_response.rationale,
836 })
837 }
838
839 pub async fn aggregate(&self, params: GotAggregateParams) -> AppResult<GotAggregateResult> {
841 let start = Instant::now();
842
843 if params.node_ids.len() < 2 {
844 return Err(ToolError::Validation {
845 field: "node_ids".to_string(),
846 reason: "At least 2 nodes required for aggregation".to_string(),
847 }
848 .into());
849 }
850
851 let mut nodes = Vec::new();
853 for id in ¶ms.node_ids {
854 let node = self
855 .core
856 .storage()
857 .get_graph_node(id)
858 .await?
859 .ok_or_else(|| ToolError::Validation {
860 field: "node_ids".to_string(),
861 reason: format!("Node not found: {}", id),
862 })?;
863 nodes.push(node);
864 }
865
866 debug!(
867 session_id = %params.session_id,
868 node_count = nodes.len(),
869 "Aggregating GoT nodes"
870 );
871
872 let messages = self.build_aggregate_messages(&nodes, params.problem.as_deref());
874
875 let mut invocation = Invocation::new(
877 "reasoning.got.aggregate",
878 serialize_for_log(¶ms, "reasoning.got.aggregate input"),
879 )
880 .with_session(¶ms.session_id)
881 .with_pipe(&self.got_pipe);
882
883 let request = PipeRequest::new(&self.got_pipe, messages);
885 let response = match self.core.langbase().call_pipe(request).await {
886 Ok(resp) => resp,
887 Err(e) => {
888 let latency = start.elapsed().as_millis() as i64;
889 invocation = invocation.failure(e.to_string(), latency);
890 if let Err(log_err) = self.core.storage().log_invocation(&invocation).await {
891 warn!(
892 error = %log_err,
893 tool = %invocation.tool_name,
894 "Failed to log invocation - audit trail incomplete"
895 );
896 }
897 return Err(e.into());
898 }
899 };
900
901 let agg_response = AggregateResponse::from_completion(&response.completion)?;
903
904 let max_depth = nodes.iter().map(|n| n.depth).max().unwrap_or(0);
906
907 let agg_node = GraphNode::new(¶ms.session_id, &agg_response.aggregated_thought)
909 .with_type(NodeType::Aggregation)
910 .with_depth(max_depth + 1)
911 .with_score(agg_response.confidence)
912 .as_active();
913
914 self.core.storage().create_graph_node(&agg_node).await?;
915
916 for node in &nodes {
918 let edge = GraphEdge::new(¶ms.session_id, &node.id, &agg_node.id)
919 .with_type(EdgeType::Aggregates);
920 self.core.storage().create_graph_edge(&edge).await?;
921
922 let mut updated = node.clone();
924 updated.is_active = false;
925 self.core.storage().update_graph_node(&updated).await?;
926 }
927
928 let latency = start.elapsed().as_millis() as i64;
929 invocation = invocation.success(
930 serialize_for_log(&agg_response, "reasoning.got.aggregate output"),
931 latency,
932 );
933 if let Err(log_err) = self.core.storage().log_invocation(&invocation).await {
934 warn!(
935 error = %log_err,
936 tool = %invocation.tool_name,
937 "Failed to log invocation - audit trail incomplete"
938 );
939 }
940
941 info!(
942 session_id = %params.session_id,
943 aggregated_node_id = %agg_node.id,
944 source_count = nodes.len(),
945 latency_ms = latency,
946 "GoT aggregate completed"
947 );
948
949 Ok(GotAggregateResult {
950 session_id: params.session_id,
951 aggregated_node_id: agg_node.id,
952 content: agg_response.aggregated_thought,
953 confidence: agg_response.confidence,
954 source_nodes: params.node_ids,
955 synthesis_approach: agg_response.synthesis_approach,
956 conflicts_resolved: agg_response.conflicts_resolved,
957 })
958 }
959
960 pub async fn refine(&self, params: GotRefineParams) -> AppResult<GotRefineResult> {
962 let start = Instant::now();
963
964 let node = self
966 .core
967 .storage()
968 .get_graph_node(¶ms.node_id)
969 .await?
970 .ok_or_else(|| ToolError::Validation {
971 field: "node_id".to_string(),
972 reason: format!("Node not found: {}", params.node_id),
973 })?;
974
975 debug!(
976 session_id = %params.session_id,
977 node_id = %node.id,
978 "Refining GoT node"
979 );
980
981 let messages = self.build_refine_messages(&node, params.problem.as_deref());
983
984 let mut invocation = Invocation::new(
986 "reasoning.got.refine",
987 serialize_for_log(¶ms, "reasoning.got.refine input"),
988 )
989 .with_session(¶ms.session_id)
990 .with_pipe(&self.got_pipe);
991
992 let request = PipeRequest::new(&self.got_pipe, messages);
994 let response = match self.core.langbase().call_pipe(request).await {
995 Ok(resp) => resp,
996 Err(e) => {
997 let latency = start.elapsed().as_millis() as i64;
998 invocation = invocation.failure(e.to_string(), latency);
999 if let Err(log_err) = self.core.storage().log_invocation(&invocation).await {
1000 warn!(
1001 error = %log_err,
1002 tool = %invocation.tool_name,
1003 "Failed to log invocation - audit trail incomplete"
1004 );
1005 }
1006 return Err(e.into());
1007 }
1008 };
1009
1010 let refine_response = RefineResponse::from_completion(&response.completion)?;
1012
1013 let refined_node = GraphNode::new(¶ms.session_id, &refine_response.refined_thought)
1015 .with_type(NodeType::Refinement)
1016 .with_depth(node.depth) .with_score(refine_response.confidence)
1018 .as_active();
1019
1020 self.core.storage().create_graph_node(&refined_node).await?;
1021
1022 let edge = GraphEdge::new(¶ms.session_id, &node.id, &refined_node.id)
1024 .with_type(EdgeType::Refines);
1025 self.core.storage().create_graph_edge(&edge).await?;
1026
1027 let mut updated_node = node.clone();
1029 updated_node.is_active = false;
1030 self.core.storage().update_graph_node(&updated_node).await?;
1031
1032 let latency = start.elapsed().as_millis() as i64;
1033 invocation = invocation.success(
1034 serialize_for_log(&refine_response, "reasoning.got.refine output"),
1035 latency,
1036 );
1037 if let Err(log_err) = self.core.storage().log_invocation(&invocation).await {
1038 warn!(
1039 error = %log_err,
1040 tool = %invocation.tool_name,
1041 "Failed to log invocation - audit trail incomplete"
1042 );
1043 }
1044
1045 info!(
1046 session_id = %params.session_id,
1047 original_node_id = %node.id,
1048 refined_node_id = %refined_node.id,
1049 quality_delta = refine_response.quality_delta,
1050 latency_ms = latency,
1051 "GoT refine completed"
1052 );
1053
1054 Ok(GotRefineResult {
1055 session_id: params.session_id,
1056 original_node_id: node.id,
1057 refined_node_id: refined_node.id,
1058 content: refine_response.refined_thought,
1059 confidence: refine_response.confidence,
1060 improvements_made: refine_response.improvements_made,
1061 quality_delta: refine_response.quality_delta,
1062 })
1063 }
1064
1065 pub async fn prune(&self, params: GotPruneParams) -> AppResult<GotPruneResult> {
1067 let start = Instant::now();
1068
1069 let threshold = params.threshold.unwrap_or(self.config.prune_threshold);
1070
1071 let nodes = self
1073 .core
1074 .storage()
1075 .get_session_graph_nodes(¶ms.session_id)
1076 .await?;
1077
1078 let mut pruned_ids = Vec::new();
1080 for node in &nodes {
1081 if node.is_root || node.is_terminal {
1083 continue;
1084 }
1085
1086 if let Some(score) = node.score {
1088 if score < threshold {
1089 let children = self.core.storage().get_edges_from(&node.id).await?;
1091 if children.is_empty() {
1092 pruned_ids.push(node.id.clone());
1093 }
1094 }
1095 }
1096 }
1097
1098 for id in &pruned_ids {
1100 let edges_from = self.core.storage().get_edges_from(id).await?;
1102 let edges_to = self.core.storage().get_edges_to(id).await?;
1103
1104 for edge in edges_from.iter().chain(edges_to.iter()) {
1105 self.core.storage().delete_graph_edge(&edge.id).await?;
1106 }
1107
1108 self.core.storage().delete_graph_node(id).await?;
1110 }
1111
1112 let remaining_count = nodes.len() - pruned_ids.len();
1113 let latency = start.elapsed().as_millis() as i64;
1114
1115 info!(
1116 session_id = %params.session_id,
1117 pruned_count = pruned_ids.len(),
1118 remaining_count = remaining_count,
1119 threshold = threshold,
1120 latency_ms = latency,
1121 "GoT prune completed"
1122 );
1123
1124 Ok(GotPruneResult {
1125 session_id: params.session_id,
1126 pruned_count: pruned_ids.len(),
1127 remaining_count,
1128 threshold_used: threshold,
1129 pruned_node_ids: pruned_ids,
1130 })
1131 }
1132
1133 pub async fn finalize(&self, params: GotFinalizeParams) -> AppResult<GotFinalizeResult> {
1135 let start = Instant::now();
1136
1137 let nodes_to_finalize = if params.terminal_node_ids.is_empty() {
1138 let active = self
1140 .core
1141 .storage()
1142 .get_active_graph_nodes(¶ms.session_id)
1143 .await?;
1144 let mut scored: Vec<_> = active.into_iter().filter(|n| n.score.is_some()).collect();
1145 scored.sort_by(|a, b| {
1146 b.score
1147 .unwrap_or(0.0)
1148 .partial_cmp(&a.score.unwrap_or(0.0))
1149 .unwrap_or(std::cmp::Ordering::Equal)
1150 });
1151 scored.into_iter().take(3).collect::<Vec<_>>()
1153 } else {
1154 let mut nodes = Vec::new();
1156 for id in ¶ms.terminal_node_ids {
1157 let node = self
1158 .core
1159 .storage()
1160 .get_graph_node(id)
1161 .await?
1162 .ok_or_else(|| ToolError::Validation {
1163 field: "terminal_node_ids".to_string(),
1164 reason: format!("Node not found: {}", id),
1165 })?;
1166 nodes.push(node);
1167 }
1168 nodes
1169 };
1170
1171 let mut conclusions = Vec::new();
1173 for node in nodes_to_finalize {
1174 let mut updated = node.clone();
1175 updated.is_terminal = true;
1176 updated.is_active = false;
1177 updated.node_type = NodeType::Terminal;
1178 self.core.storage().update_graph_node(&updated).await?;
1179
1180 conclusions.push(TerminalConclusion {
1181 node_id: node.id,
1182 content: node.content,
1183 score: node.score,
1184 depth: node.depth,
1185 });
1186 }
1187
1188 let latency = start.elapsed().as_millis() as i64;
1189 info!(
1190 session_id = %params.session_id,
1191 terminal_count = conclusions.len(),
1192 latency_ms = latency,
1193 "GoT finalize completed"
1194 );
1195
1196 Ok(GotFinalizeResult {
1197 session_id: params.session_id,
1198 terminal_count: conclusions.len(),
1199 conclusions,
1200 })
1201 }
1202
1203 pub async fn get_state(&self, params: GotGetStateParams) -> AppResult<GotStateResult> {
1205 let nodes = self
1206 .core
1207 .storage()
1208 .get_session_graph_nodes(¶ms.session_id)
1209 .await?;
1210 let edges = self
1211 .core
1212 .storage()
1213 .get_session_edges(¶ms.session_id)
1214 .await?;
1215
1216 let active_nodes: Vec<_> = nodes.iter().filter(|n| n.is_active).collect();
1217 let terminal_nodes: Vec<_> = nodes.iter().filter(|n| n.is_terminal).collect();
1218 let root_nodes: Vec<_> = nodes.iter().filter(|n| n.is_root).collect();
1219 let max_depth = nodes.iter().map(|n| n.depth).max().unwrap_or(0);
1220
1221 Ok(GotStateResult {
1222 session_id: params.session_id,
1223 total_nodes: nodes.len(),
1224 active_nodes: active_nodes.len(),
1225 terminal_nodes: terminal_nodes.len(),
1226 total_edges: edges.len(),
1227 max_depth,
1228 root_node_ids: root_nodes.iter().map(|n| n.id.clone()).collect(),
1229 active_node_ids: active_nodes.iter().map(|n| n.id.clone()).collect(),
1230 terminal_node_ids: terminal_nodes.iter().map(|n| n.id.clone()).collect(),
1231 })
1232 }
1233
1234 pub async fn has_cycle(&self, session_id: &str) -> AppResult<bool> {
1236 let nodes = self
1237 .core
1238 .storage()
1239 .get_session_graph_nodes(session_id)
1240 .await?;
1241 let edges = self.core.storage().get_session_edges(session_id).await?;
1242
1243 let mut adj: std::collections::HashMap<String, Vec<String>> =
1245 std::collections::HashMap::new();
1246 for edge in &edges {
1247 adj.entry(edge.from_node.clone())
1248 .or_default()
1249 .push(edge.to_node.clone());
1250 }
1251
1252 let mut visited = HashSet::new();
1254 let mut rec_stack = HashSet::new();
1255
1256 fn dfs(
1257 node: &str,
1258 adj: &std::collections::HashMap<String, Vec<String>>,
1259 visited: &mut HashSet<String>,
1260 rec_stack: &mut HashSet<String>,
1261 ) -> bool {
1262 visited.insert(node.to_string());
1263 rec_stack.insert(node.to_string());
1264
1265 if let Some(neighbors) = adj.get(node) {
1266 for neighbor in neighbors {
1267 if !visited.contains(neighbor) {
1268 if dfs(neighbor, adj, visited, rec_stack) {
1269 return true;
1270 }
1271 } else if rec_stack.contains(neighbor) {
1272 return true;
1273 }
1274 }
1275 }
1276
1277 rec_stack.remove(node);
1278 false
1279 }
1280
1281 for node in &nodes {
1282 if !visited.contains(&node.id) && dfs(&node.id, &adj, &mut visited, &mut rec_stack) {
1283 return Ok(true);
1284 }
1285 }
1286
1287 Ok(false)
1288 }
1289
1290 fn build_generate_messages(
1295 &self,
1296 source_node: &GraphNode,
1297 k: usize,
1298 problem: Option<&str>,
1299 ) -> Vec<Message> {
1300 let mut messages = Vec::new();
1301 messages.push(Message::system(GOT_GENERATE_PROMPT));
1302
1303 let mut user_msg = format!(
1304 "Generate {} diverse continuations from this thought:\n\n\"{}\"",
1305 k, source_node.content
1306 );
1307
1308 if let Some(p) = problem {
1309 user_msg.push_str(&format!("\n\nProblem context: {}", p));
1310 }
1311
1312 user_msg.push_str(&format!("\n\nCurrent depth: {}", source_node.depth));
1313
1314 messages.push(Message::user(user_msg));
1315 messages
1316 }
1317
1318 fn build_score_messages(&self, node: &GraphNode, problem: Option<&str>) -> Vec<Message> {
1319 let mut messages = Vec::new();
1320 messages.push(Message::system(GOT_SCORE_PROMPT));
1321
1322 let mut user_msg = format!("Score this thought:\n\n\"{}\"", node.content);
1323
1324 if let Some(p) = problem {
1325 user_msg.push_str(&format!("\n\nProblem context: {}", p));
1326 }
1327
1328 user_msg.push_str(&format!("\n\nDepth: {}", node.depth));
1329 if let Some(score) = node.score {
1330 user_msg.push_str(&format!("\nPrevious score: {}", score));
1331 }
1332
1333 messages.push(Message::user(user_msg));
1334 messages
1335 }
1336
1337 fn build_aggregate_messages(&self, nodes: &[GraphNode], problem: Option<&str>) -> Vec<Message> {
1338 let mut messages = Vec::new();
1339 messages.push(Message::system(GOT_AGGREGATE_PROMPT));
1340
1341 let thoughts: Vec<String> = nodes
1342 .iter()
1343 .enumerate()
1344 .map(|(i, n)| format!("{}. \"{}\"", i + 1, n.content))
1345 .collect();
1346
1347 let mut user_msg = format!(
1348 "Aggregate these {} thoughts into a unified insight:\n\n{}",
1349 nodes.len(),
1350 thoughts.join("\n\n")
1351 );
1352
1353 if let Some(p) = problem {
1354 user_msg.push_str(&format!("\n\nProblem context: {}", p));
1355 }
1356
1357 messages.push(Message::user(user_msg));
1358 messages
1359 }
1360
1361 fn build_refine_messages(&self, node: &GraphNode, problem: Option<&str>) -> Vec<Message> {
1362 let mut messages = Vec::new();
1363 messages.push(Message::system(GOT_REFINE_PROMPT));
1364
1365 let mut user_msg = format!("Refine and improve this thought:\n\n\"{}\"", node.content);
1366
1367 if let Some(p) = problem {
1368 user_msg.push_str(&format!("\n\nProblem context: {}", p));
1369 }
1370
1371 if let Some(score) = node.score {
1372 user_msg.push_str(&format!("\n\nCurrent score: {:.2}", score));
1373 }
1374
1375 messages.push(Message::user(user_msg));
1376 messages
1377 }
1378}
1379
1380impl GotInitParams {
1385 pub fn new(content: impl Into<String>) -> Self {
1387 Self {
1388 content: content.into(),
1389 problem: None,
1390 session_id: None,
1391 config: None,
1392 }
1393 }
1394
1395 pub fn with_problem(mut self, problem: impl Into<String>) -> Self {
1397 self.problem = Some(problem.into());
1398 self
1399 }
1400
1401 pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
1403 self.session_id = Some(session_id.into());
1404 self
1405 }
1406
1407 pub fn with_config(mut self, config: GotConfig) -> Self {
1409 self.config = Some(config);
1410 self
1411 }
1412}
1413
1414impl GotGenerateParams {
1415 pub fn new(session_id: impl Into<String>) -> Self {
1417 Self {
1418 session_id: session_id.into(),
1419 node_id: None,
1420 k: default_k(),
1421 problem: None,
1422 }
1423 }
1424
1425 pub fn with_node(mut self, node_id: impl Into<String>) -> Self {
1427 self.node_id = Some(node_id.into());
1428 self
1429 }
1430
1431 pub fn with_k(mut self, k: usize) -> Self {
1433 self.k = k;
1434 self
1435 }
1436
1437 pub fn with_problem(mut self, problem: impl Into<String>) -> Self {
1439 self.problem = Some(problem.into());
1440 self
1441 }
1442}
1443
1444impl GotScoreParams {
1445 pub fn new(session_id: impl Into<String>, node_id: impl Into<String>) -> Self {
1447 Self {
1448 session_id: session_id.into(),
1449 node_id: node_id.into(),
1450 problem: None,
1451 }
1452 }
1453
1454 pub fn with_problem(mut self, problem: impl Into<String>) -> Self {
1456 self.problem = Some(problem.into());
1457 self
1458 }
1459}
1460
1461impl GotAggregateParams {
1462 pub fn new(session_id: impl Into<String>, node_ids: Vec<String>) -> Self {
1464 Self {
1465 session_id: session_id.into(),
1466 node_ids,
1467 problem: None,
1468 }
1469 }
1470
1471 pub fn with_problem(mut self, problem: impl Into<String>) -> Self {
1473 self.problem = Some(problem.into());
1474 self
1475 }
1476}
1477
1478impl GotRefineParams {
1479 pub fn new(session_id: impl Into<String>, node_id: impl Into<String>) -> Self {
1481 Self {
1482 session_id: session_id.into(),
1483 node_id: node_id.into(),
1484 problem: None,
1485 }
1486 }
1487
1488 pub fn with_problem(mut self, problem: impl Into<String>) -> Self {
1490 self.problem = Some(problem.into());
1491 self
1492 }
1493}
1494
1495impl GotPruneParams {
1496 pub fn new(session_id: impl Into<String>) -> Self {
1498 Self {
1499 session_id: session_id.into(),
1500 threshold: None,
1501 }
1502 }
1503
1504 pub fn with_threshold(mut self, threshold: f64) -> Self {
1506 self.threshold = Some(threshold);
1507 self
1508 }
1509}
1510
1511impl GotFinalizeParams {
1512 pub fn new(session_id: impl Into<String>) -> Self {
1514 Self {
1515 session_id: session_id.into(),
1516 terminal_node_ids: vec![],
1517 }
1518 }
1519
1520 pub fn with_terminal_nodes(mut self, node_ids: Vec<String>) -> Self {
1522 self.terminal_node_ids = node_ids;
1523 self
1524 }
1525}
1526
1527impl GotGetStateParams {
1528 pub fn new(session_id: impl Into<String>) -> Self {
1530 Self {
1531 session_id: session_id.into(),
1532 }
1533 }
1534}