agm_core/model/
context.rs1use serde::{Deserialize, Serialize};
4use std::fmt;
5
/// Which portion of a file to load.
///
/// Wire format (implemented by the manual `Serialize`/`Deserialize` impls for this type):
/// - [`FileRange::Full`] ⇄ `"full"`
/// - [`FileRange::Lines`] ⇄ `[start, end]`
/// - [`FileRange::Function`] ⇄ `"function: <name>"`
///
/// Every payload is a `u64` or a `String`, so the type also derives `Eq` and `Hash`.
/// This lets ranges be deduplicated in sets or used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FileRange {
    /// The entire file.
    Full,
    /// A span of lines given as `(start, end)`.
    ///
    /// NOTE(review): nothing in this module fixes whether lines are 1-based, whether
    /// `end` is inclusive, or that `start <= end` — confirm against the loader.
    Lines(u64, u64),
    /// A function identified by name.
    ///
    /// NOTE(review): resolving the name to a region of the file is not done in this
    /// module — confirm where that happens.
    Function(String),
}
16
17impl FileRange {
18 #[must_use]
19 pub fn full() -> Self {
20 Self::Full
21 }
22
23 #[must_use]
24 pub fn lines(start: u64, end: u64) -> Self {
25 Self::Lines(start, end)
26 }
27
28 #[must_use]
29 pub fn function(name: impl Into<String>) -> Self {
30 Self::Function(name.into())
31 }
32}
33
34impl serde::Serialize for FileRange {
35 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
36 match self {
37 Self::Full => serializer.serialize_str("full"),
38 Self::Lines(start, end) => {
39 use serde::ser::SerializeSeq;
40 let mut seq = serializer.serialize_seq(Some(2))?;
41 seq.serialize_element(start)?;
42 seq.serialize_element(end)?;
43 seq.end()
44 }
45 Self::Function(name) => serializer.serialize_str(&format!("function: {name}")),
46 }
47 }
48}
49
50impl<'de> serde::Deserialize<'de> for FileRange {
51 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
52 use serde::de;
53
54 struct FileRangeVisitor;
55
56 impl<'de> de::Visitor<'de> for FileRangeVisitor {
57 type Value = FileRange;
58
59 fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 f.write_str("\"full\", [start, end], or \"function: <name>\"")
61 }
62
63 fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
64 if v == "full" {
65 Ok(FileRange::Full)
66 } else if let Some(name) = v.strip_prefix("function:") {
67 Ok(FileRange::Function(name.trim().to_owned()))
68 } else {
69 Err(E::custom(format!("invalid file range: {v:?}")))
70 }
71 }
72
73 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
74 let start: u64 = seq
75 .next_element()?
76 .ok_or_else(|| de::Error::invalid_length(0, &"2"))?;
77 let end: u64 = seq
78 .next_element()?
79 .ok_or_else(|| de::Error::invalid_length(1, &"2"))?;
80 Ok(FileRange::Lines(start, end))
81 }
82 }
83
84 deserializer.deserialize_any(FileRangeVisitor)
85 }
86}
87
88#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
93pub struct LoadFile {
94 pub path: String,
95 pub range: FileRange,
96}
97
98#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
103pub struct AgentContext {
104 #[serde(skip_serializing_if = "Option::is_none")]
105 pub load_nodes: Option<Vec<String>>,
106 #[serde(skip_serializing_if = "Option::is_none")]
107 pub load_files: Option<Vec<LoadFile>>,
108 #[serde(skip_serializing_if = "Option::is_none")]
109 pub system_hint: Option<String>,
110 #[serde(skip_serializing_if = "Option::is_none")]
111 pub max_tokens: Option<u64>,
112 #[serde(skip_serializing_if = "Option::is_none")]
113 pub load_memory: Option<Vec<String>>,
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_file_range_full_serde() {
        let r = FileRange::Full;
        let json = serde_json::to_string(&r).unwrap();
        assert_eq!(json, "\"full\"");
        let back: FileRange = serde_json::from_str(&json).unwrap();
        assert_eq!(r, back);
    }

    #[test]
    fn test_file_range_lines_serde() {
        let r = FileRange::Lines(1, 50);
        let json = serde_json::to_string(&r).unwrap();
        assert_eq!(json, "[1,50]");
        let back: FileRange = serde_json::from_str(&json).unwrap();
        assert_eq!(r, back);
    }

    #[test]
    fn test_file_range_function_serde() {
        let r = FileRange::Function("handle_request".to_owned());
        let json = serde_json::to_string(&r).unwrap();
        assert_eq!(json, "\"function: handle_request\"");
        let back: FileRange = serde_json::from_str(&json).unwrap();
        assert_eq!(r, back);
    }

    #[test]
    fn test_file_range_function_without_space_is_accepted() {
        // The name is trimmed after the `function:` prefix, so the space is optional.
        let back: FileRange = serde_json::from_str("\"function:handle_request\"").unwrap();
        assert_eq!(back, FileRange::Function("handle_request".to_owned()));
    }

    #[test]
    fn test_file_range_rejects_unknown_string() {
        assert!(serde_json::from_str::<FileRange>("\"partial\"").is_err());
    }

    #[test]
    fn test_file_range_rejects_short_sequence() {
        assert!(serde_json::from_str::<FileRange>("[1]").is_err());
        assert!(serde_json::from_str::<FileRange>("[]").is_err());
    }

    #[test]
    fn test_file_range_rejects_non_string_non_array() {
        assert!(serde_json::from_str::<FileRange>("42").is_err());
        assert!(serde_json::from_str::<FileRange>("{}").is_err());
    }

    #[test]
    fn test_file_range_constructors() {
        assert_eq!(FileRange::full(), FileRange::Full);
        assert_eq!(FileRange::lines(3, 7), FileRange::Lines(3, 7));
        assert_eq!(
            FileRange::function("f"),
            FileRange::Function("f".to_owned())
        );
    }

    #[test]
    fn test_load_file_serde_roundtrip() {
        let lf = LoadFile {
            path: "src/main.rs".to_owned(),
            range: FileRange::Full,
        };
        let json = serde_json::to_string(&lf).unwrap();
        let back: LoadFile = serde_json::from_str(&json).unwrap();
        assert_eq!(lf, back);
    }

    #[test]
    fn test_agent_context_full_serde() {
        let ctx = AgentContext {
            load_nodes: Some(vec!["auth.login".to_owned()]),
            load_files: Some(vec![LoadFile {
                path: "src/auth.rs".to_owned(),
                range: FileRange::Lines(1, 50),
            }]),
            system_hint: Some("Rust project".to_owned()),
            max_tokens: Some(4000),
            load_memory: Some(vec!["rust.repository".to_owned()]),
        };
        let json = serde_json::to_string(&ctx).unwrap();
        let back: AgentContext = serde_json::from_str(&json).unwrap();
        assert_eq!(ctx, back);
    }

    #[test]
    fn test_agent_context_minimal_serde() {
        let ctx = AgentContext {
            load_nodes: None,
            load_files: None,
            system_hint: Some("hint".to_owned()),
            max_tokens: None,
            load_memory: None,
        };
        let json = serde_json::to_string(&ctx).unwrap();
        assert!(!json.contains("load_nodes"));
        assert!(!json.contains("load_files"));
        assert!(!json.contains("max_tokens"));
        assert!(!json.contains("load_memory"));
        let back: AgentContext = serde_json::from_str(&json).unwrap();
        assert_eq!(ctx, back);
    }

    #[test]
    fn test_agent_context_deserialize_from_spec_json() {
        let json = r#"{
            "load_nodes": ["auth.constraints", "auth.session"],
            "load_files": [
                {"path": "src/handlers/auth.rs", "range": "full"}
            ],
            "system_hint": "Rust project using actix-web."
        }"#;
        let ctx: AgentContext = serde_json::from_str(json).unwrap();
        assert_eq!(ctx.load_nodes.as_ref().unwrap().len(), 2);
        assert_eq!(ctx.load_files.as_ref().unwrap().len(), 1);
        assert_eq!(ctx.load_files.as_ref().unwrap()[0].range, FileRange::Full);
    }
}