Skip to main content

agm_core/model/
context.rs

1//! Agent context types (spec S25).
2
3use serde::{Deserialize, Serialize};
4use std::fmt;
5
6// ---------------------------------------------------------------------------
7// FileRange
8// ---------------------------------------------------------------------------
9
/// Which portion of a file an agent should load (spec S25).
///
/// The manual `Serialize`/`Deserialize` impls in this module encode the
/// variants on the wire as `"full"`, `[start, end]`, and `"function: <name>"`
/// respectively.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FileRange {
    /// The entire file.
    Full,
    /// A span of lines given as `(start, end)`.
    // NOTE(review): whether lines are 0- or 1-based and whether `end` is
    // inclusive is not established in this module — confirm against spec S25.
    Lines(u64, u64),
    /// A single function, identified by its name.
    Function(String),
}
16
17impl FileRange {
18    #[must_use]
19    pub fn full() -> Self {
20        Self::Full
21    }
22
23    #[must_use]
24    pub fn lines(start: u64, end: u64) -> Self {
25        Self::Lines(start, end)
26    }
27
28    #[must_use]
29    pub fn function(name: impl Into<String>) -> Self {
30        Self::Function(name.into())
31    }
32}
33
34impl serde::Serialize for FileRange {
35    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
36        match self {
37            Self::Full => serializer.serialize_str("full"),
38            Self::Lines(start, end) => {
39                use serde::ser::SerializeSeq;
40                let mut seq = serializer.serialize_seq(Some(2))?;
41                seq.serialize_element(start)?;
42                seq.serialize_element(end)?;
43                seq.end()
44            }
45            Self::Function(name) => serializer.serialize_str(&format!("function: {name}")),
46        }
47    }
48}
49
50impl<'de> serde::Deserialize<'de> for FileRange {
51    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
52        use serde::de;
53
54        struct FileRangeVisitor;
55
56        impl<'de> de::Visitor<'de> for FileRangeVisitor {
57            type Value = FileRange;
58
59            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60                f.write_str("\"full\", [start, end], or \"function: <name>\"")
61            }
62
63            fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
64                if v == "full" {
65                    Ok(FileRange::Full)
66                } else if let Some(name) = v.strip_prefix("function:") {
67                    Ok(FileRange::Function(name.trim().to_owned()))
68                } else {
69                    Err(E::custom(format!("invalid file range: {v:?}")))
70                }
71            }
72
73            fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
74                let start: u64 = seq
75                    .next_element()?
76                    .ok_or_else(|| de::Error::invalid_length(0, &"2"))?;
77                let end: u64 = seq
78                    .next_element()?
79                    .ok_or_else(|| de::Error::invalid_length(1, &"2"))?;
80                Ok(FileRange::Lines(start, end))
81            }
82        }
83
84        deserializer.deserialize_any(FileRangeVisitor)
85    }
86}
87
88// ---------------------------------------------------------------------------
89// LoadFile
90// ---------------------------------------------------------------------------
91
92#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
93pub struct LoadFile {
94    pub path: String,
95    pub range: FileRange,
96}
97
98// ---------------------------------------------------------------------------
99// AgentContext
100// ---------------------------------------------------------------------------
101
102#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
103pub struct AgentContext {
104    #[serde(skip_serializing_if = "Option::is_none")]
105    pub load_nodes: Option<Vec<String>>,
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub load_files: Option<Vec<LoadFile>>,
108    #[serde(skip_serializing_if = "Option::is_none")]
109    pub system_hint: Option<String>,
110    #[serde(skip_serializing_if = "Option::is_none")]
111    pub max_tokens: Option<u64>,
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub load_memory: Option<Vec<String>>,
114}
115
#[cfg(test)]
mod tests {
    //! Round-trip and error-path tests for the agent context serde encodings.

    use super::*;

    #[test]
    fn test_file_range_full_serde() {
        let r = FileRange::Full;
        let json = serde_json::to_string(&r).unwrap();
        assert_eq!(json, "\"full\"");
        let back: FileRange = serde_json::from_str(&json).unwrap();
        assert_eq!(r, back);
    }

    #[test]
    fn test_file_range_lines_serde() {
        let r = FileRange::Lines(1, 50);
        let json = serde_json::to_string(&r).unwrap();
        assert_eq!(json, "[1,50]");
        let back: FileRange = serde_json::from_str(&json).unwrap();
        assert_eq!(r, back);
    }

    #[test]
    fn test_file_range_function_serde() {
        let r = FileRange::Function("handle_request".to_owned());
        let json = serde_json::to_string(&r).unwrap();
        assert_eq!(json, "\"function: handle_request\"");
        let back: FileRange = serde_json::from_str(&json).unwrap();
        assert_eq!(r, back);
    }

    #[test]
    fn test_file_range_constructors_match_variants() {
        assert_eq!(FileRange::full(), FileRange::Full);
        assert_eq!(FileRange::lines(3, 7), FileRange::Lines(3, 7));
        assert_eq!(
            FileRange::function("f"),
            FileRange::Function("f".to_owned())
        );
    }

    #[test]
    fn test_file_range_function_name_is_trimmed() {
        // The space after the colon is optional, and surrounding whitespace is trimmed.
        let r: FileRange = serde_json::from_str("\"function:handle_request\"").unwrap();
        assert_eq!(r, FileRange::function("handle_request"));
        let r: FileRange = serde_json::from_str("\"function:   spaced  \"").unwrap();
        assert_eq!(r, FileRange::function("spaced"));
    }

    #[test]
    fn test_file_range_rejects_invalid_strings() {
        for bad in ["\"partial\"", "\"Full\"", "\"\""] {
            assert!(
                serde_json::from_str::<FileRange>(bad).is_err(),
                "accepted {bad}"
            );
        }
    }

    #[test]
    fn test_file_range_rejects_wrong_length_arrays() {
        for bad in ["[]", "[1]", "[1, 2, 3]"] {
            assert!(
                serde_json::from_str::<FileRange>(bad).is_err(),
                "accepted {bad}"
            );
        }
    }

    #[test]
    fn test_file_range_rejects_other_shapes() {
        for bad in ["[-1, 2]", "[\"a\", 2]", "[1.5, 2]", "42", "{}"] {
            assert!(
                serde_json::from_str::<FileRange>(bad).is_err(),
                "accepted {bad}"
            );
        }
    }

    #[test]
    fn test_load_file_serde_roundtrip() {
        let lf = LoadFile {
            path: "src/main.rs".to_owned(),
            range: FileRange::Full,
        };
        let json = serde_json::to_string(&lf).unwrap();
        let back: LoadFile = serde_json::from_str(&json).unwrap();
        assert_eq!(lf, back);
    }

    #[test]
    fn test_agent_context_full_serde() {
        let ctx = AgentContext {
            load_nodes: Some(vec!["auth.login".to_owned()]),
            load_files: Some(vec![LoadFile {
                path: "src/auth.rs".to_owned(),
                range: FileRange::Lines(1, 50),
            }]),
            system_hint: Some("Rust project".to_owned()),
            max_tokens: Some(4000),
            load_memory: Some(vec!["rust.repository".to_owned()]),
        };
        let json = serde_json::to_string(&ctx).unwrap();
        let back: AgentContext = serde_json::from_str(&json).unwrap();
        assert_eq!(ctx, back);
    }

    #[test]
    fn test_agent_context_minimal_serde() {
        let ctx = AgentContext {
            load_nodes: None,
            load_files: None,
            system_hint: Some("hint".to_owned()),
            max_tokens: None,
            load_memory: None,
        };
        let json = serde_json::to_string(&ctx).unwrap();
        assert!(!json.contains("load_nodes"));
        assert!(!json.contains("load_files"));
        assert!(!json.contains("max_tokens"));
        assert!(!json.contains("load_memory"));
        let back: AgentContext = serde_json::from_str(&json).unwrap();
        assert_eq!(ctx, back);
    }

    #[test]
    fn test_agent_context_empty_object() {
        // Every field is optional, so an empty object is a valid context and
        // serializes back to an empty object.
        let ctx: AgentContext = serde_json::from_str("{}").unwrap();
        assert_eq!(ctx.load_nodes, None);
        assert_eq!(ctx.load_files, None);
        assert_eq!(ctx.system_hint, None);
        assert_eq!(ctx.max_tokens, None);
        assert_eq!(ctx.load_memory, None);
        assert_eq!(serde_json::to_string(&ctx).unwrap(), "{}");
    }

    #[test]
    fn test_agent_context_deserialize_from_spec_json() {
        let json = r#"{
            "load_nodes": ["auth.constraints", "auth.session"],
            "load_files": [
                {"path": "src/handlers/auth.rs", "range": "full"}
            ],
            "system_hint": "Rust project using actix-web."
        }"#;
        let ctx: AgentContext = serde_json::from_str(json).unwrap();
        assert_eq!(ctx.load_nodes.as_ref().unwrap().len(), 2);
        assert_eq!(ctx.load_files.as_ref().unwrap().len(), 1);
        assert_eq!(ctx.load_files.as_ref().unwrap()[0].range, FileRange::Full);
    }
}