rab/agent/
footer_data_provider.rs1use std::collections::BTreeMap;
2use std::fs;
3use std::path::{Path, PathBuf};
4
/// State backing the footer display.
///
/// Holds the git branch of the working directory, per-key extension status
/// strings, the number of available providers, and the model provider/id
/// taken from the session's most recent model change.
#[derive(Debug, Clone)]
pub struct FooterDataProvider {
    /// Directory whose enclosing git repository supplies `git_branch`.
    cwd: PathBuf,
    /// Branch resolved for `cwd` by the last git refresh; `None` when nothing
    /// could be resolved (e.g. `cwd` is not inside a repository).
    git_branch: Option<String>,
    /// Status text per key. A `BTreeMap` so iteration is in sorted key order.
    extension_statuses: BTreeMap<String, String>,
    /// Number of available providers as last reported to this struct.
    available_provider_count: usize,
    /// Provider from the most recent model-change entry in the session.
    model_provider: Option<String>,
    /// Model id from the most recent model-change entry in the session.
    model_id: Option<String>,
}
28
29impl FooterDataProvider {
30 pub fn new(cwd: PathBuf) -> Self {
31 let mut provider = Self {
32 cwd,
33 git_branch: None,
34 extension_statuses: BTreeMap::new(),
35 available_provider_count: 1,
36 model_provider: None,
37 model_id: None,
38 };
39 provider.refresh_git_branch();
40 provider
41 }
42
43 pub fn get_git_branch(&self) -> Option<&str> {
46 self.git_branch.as_deref()
47 }
48
49 pub fn refresh_git_branch(&mut self) {
51 self.git_branch = resolve_git_branch(&self.cwd);
52 }
53
54 pub fn set_cwd(&mut self, cwd: PathBuf) {
55 self.cwd = cwd;
56 self.refresh_git_branch();
57 }
58
59 pub fn get_extension_statuses(&self) -> &BTreeMap<String, String> {
62 &self.extension_statuses
63 }
64
65 pub fn set_extension_status(&mut self, key: &str, text: Option<&str>) {
66 if let Some(text) = text {
67 self.extension_statuses
68 .insert(key.to_string(), text.to_string());
69 } else {
70 self.extension_statuses.remove(key);
71 }
72 }
73
74 pub fn clear_extension_statuses(&mut self) {
75 self.extension_statuses.clear();
76 }
77
78 pub fn get_model_provider(&self) -> Option<&str> {
81 self.model_provider.as_deref()
82 }
83
84 pub fn get_model_id(&self) -> Option<&str> {
85 self.model_id.as_deref()
86 }
87
88 pub fn refresh_from_session(&mut self, session: &crate::agent::session::Session) {
91 let mut latest_provider: Option<String> = None;
92 let mut latest_model_id: Option<String> = None;
93
94 for entry in session.get_entries() {
95 if let crate::agent::session::SessionEntry::ModelChange(e) = entry {
96 latest_provider = Some(e.provider.clone());
97 latest_model_id = Some(e.model_id.clone());
98 }
99 }
100
101 self.model_provider = latest_provider;
102 self.model_id = latest_model_id;
103 }
104
105 #[cfg(test)]
107 pub fn set_test_model_provider(&mut self, provider: Option<&str>) {
108 self.model_provider = provider.map(|s| s.to_string());
109 }
110
111 #[cfg(test)]
113 pub fn set_test_model_id(&mut self, model_id: Option<&str>) {
114 self.model_id = model_id.map(|s| s.to_string());
115 }
116
117 pub fn get_available_provider_count(&self) -> usize {
118 self.available_provider_count
119 }
120
121 pub fn set_available_provider_count(&mut self, count: usize) {
122 self.available_provider_count = count;
123 }
124
125 #[cfg(test)]
127 pub fn set_test_git_branch(&mut self, branch: Option<&str>) {
128 self.git_branch = branch.map(|s| s.to_string());
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Uniquely named scratch directory under the system temp dir that is
    /// removed on drop, so a failing assertion no longer leaks it (trailing
    /// `remove_dir_all` calls never run after a panic).
    struct TempDir(PathBuf);

    impl TempDir {
        fn new() -> Self {
            let path = std::env::temp_dir().join(format!("rab-test-{}", uuid::Uuid::new_v4()));
            std::fs::create_dir_all(&path).unwrap();
            TempDir(path)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best effort: a cleanup failure must not mask the test outcome.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Creates `<root>/.git/HEAD` containing `contents`; returns the HEAD path.
    fn write_git_head(root: &Path, contents: &str) -> PathBuf {
        let git_dir = root.join(".git");
        std::fs::create_dir_all(&git_dir).unwrap();
        let head = git_dir.join("HEAD");
        std::fs::write(&head, contents).unwrap();
        head
    }

    /// Creates `root/real-git-dir` (a git dir whose HEAD names `branch`) and an
    /// empty `root/worktree` directory. Returns `(worktree, real_git_dir)`; the
    /// caller writes the worktree's `.git` file.
    fn make_linked_layout(root: &Path, branch: &str) -> (PathBuf, PathBuf) {
        let real_git_dir = root.join("real-git-dir");
        std::fs::create_dir_all(&real_git_dir).unwrap();
        std::fs::write(
            real_git_dir.join("HEAD"),
            format!("ref: refs/heads/{}\n", branch),
        )
        .unwrap();
        let worktree = root.join("worktree");
        std::fs::create_dir_all(&worktree).unwrap();
        (worktree, real_git_dir)
    }

    #[test]
    fn test_new_provider_refreshes_git_branch() {
        let tmp = TempDir::new();
        write_git_head(tmp.path(), "ref: refs/heads/trunk\n");
        let provider = FooterDataProvider::new(tmp.path().to_path_buf());
        assert_eq!(provider.get_git_branch(), Some("trunk"));
    }

    #[test]
    fn test_new_provider_without_repo_has_no_branch() {
        let tmp = TempDir::new();
        let provider = FooterDataProvider::new(tmp.path().to_path_buf());
        assert!(provider.get_git_branch().is_none());
    }

    #[test]
    fn test_set_test_git_branch() {
        let mut provider = FooterDataProvider::new(PathBuf::from("/tmp"));
        provider.set_test_git_branch(Some("main"));
        assert_eq!(provider.get_git_branch(), Some("main"));
    }

    #[test]
    fn test_set_test_git_branch_none() {
        let mut provider = FooterDataProvider::new(PathBuf::from("/tmp"));
        provider.set_test_git_branch(Some("feature"));
        provider.set_test_git_branch(None);
        assert!(provider.get_git_branch().is_none());
    }

    #[test]
    fn test_extension_statuses() {
        let mut provider = FooterDataProvider::new(PathBuf::from("/tmp"));
        assert!(provider.get_extension_statuses().is_empty());

        provider.set_extension_status("bash", Some("ready"));
        assert_eq!(
            provider.get_extension_statuses().get("bash"),
            Some(&"ready".to_string())
        );

        provider.set_extension_status("bash", None);
        assert!(provider.get_extension_statuses().is_empty());
    }

    #[test]
    fn test_extension_statuses_sorted() {
        let mut provider = FooterDataProvider::new(PathBuf::from("/tmp"));
        provider.set_extension_status("zzz", Some("last"));
        provider.set_extension_status("aaa", Some("first"));
        provider.set_extension_status("mmm", Some("middle"));

        let keys: Vec<&String> = provider.get_extension_statuses().keys().collect();
        assert_eq!(keys, vec!["aaa", "mmm", "zzz"]);
    }

    #[test]
    fn test_clear_extension_statuses() {
        let mut provider = FooterDataProvider::new(PathBuf::from("/tmp"));
        provider.set_extension_status("bash", Some("ready"));
        provider.clear_extension_statuses();
        assert!(provider.get_extension_statuses().is_empty());
    }

    #[test]
    fn test_provider_count() {
        let mut provider = FooterDataProvider::new(PathBuf::from("/tmp"));
        assert_eq!(provider.get_available_provider_count(), 1);
        provider.set_available_provider_count(3);
        assert_eq!(provider.get_available_provider_count(), 3);
    }

    #[test]
    fn test_set_cwd_refreshes_git_branch() {
        let tmp = TempDir::new();
        write_git_head(tmp.path(), "ref: refs/heads/new-branch\n");

        let mut provider = FooterDataProvider::new(PathBuf::from("/tmp"));
        provider.set_test_git_branch(Some("old-branch"));
        provider.set_cwd(tmp.path().to_path_buf());
        assert_eq!(provider.get_git_branch(), Some("new-branch"));
    }

    #[test]
    fn test_set_cwd_to_non_repo_clears_git_branch() {
        let tmp = TempDir::new();
        let mut provider = FooterDataProvider::new(PathBuf::from("/tmp"));
        provider.set_test_git_branch(Some("old-branch"));
        provider.set_cwd(tmp.path().to_path_buf());
        assert!(provider.get_git_branch().is_none());
    }

    #[test]
    fn test_model_provider_defaults() {
        let provider = FooterDataProvider::new(PathBuf::from("/tmp"));
        assert!(provider.get_model_provider().is_none());
        assert!(provider.get_model_id().is_none());
    }

    #[test]
    fn test_set_test_model_provider() {
        let mut provider = FooterDataProvider::new(PathBuf::from("/tmp"));
        provider.set_test_model_provider(Some("opencode-go"));
        assert_eq!(provider.get_model_provider(), Some("opencode-go"));
        provider.set_test_model_provider(None);
        assert!(provider.get_model_provider().is_none());
    }

    #[test]
    fn test_set_test_model_id() {
        let mut provider = FooterDataProvider::new(PathBuf::from("/tmp"));
        provider.set_test_model_id(Some("deepseek-v4-flash"));
        assert_eq!(provider.get_model_id(), Some("deepseek-v4-flash"));
        provider.set_test_model_id(None);
        assert!(provider.get_model_id().is_none());
    }

    #[test]
    fn test_refresh_from_session_extracts_latest_model_change() {
        use crate::agent::SessionMetadata;
        use crate::agent::session::InMemorySessionStorage;
        use crate::agent::session::*;

        let meta = SessionMetadata {
            id: "test".into(),
            created_at: String::new(),
            cwd: "/tmp".into(),
            path: None,
            parent_session_path: None,
        };
        let storage = InMemorySessionStorage::new(meta);
        let mut session = Session::new(Box::new(storage));
        session.append_model_change("provider-a", "model-a");
        session.append_model_change("provider-b", "model-b");

        let mut provider = FooterDataProvider::new(PathBuf::from("/tmp"));
        provider.refresh_from_session(&session);

        assert_eq!(provider.get_model_provider(), Some("provider-b"));
        assert_eq!(provider.get_model_id(), Some("model-b"));
    }

    #[test]
    fn test_refresh_from_session_no_model_change() {
        use crate::agent::SessionMetadata;
        use crate::agent::session::InMemorySessionStorage;
        use crate::agent::session::*;

        let meta = SessionMetadata {
            id: "test".into(),
            created_at: String::new(),
            cwd: "/tmp".into(),
            path: None,
            parent_session_path: None,
        };
        let storage = InMemorySessionStorage::new(meta);
        let session = Session::new(Box::new(storage));

        let mut provider = FooterDataProvider::new(PathBuf::from("/tmp"));
        provider.set_test_model_provider(Some("old"));
        provider.set_test_model_id(Some("old-model"));
        provider.refresh_from_session(&session);

        assert!(provider.get_model_provider().is_none());
        assert!(provider.get_model_id().is_none());
    }

    #[test]
    fn test_find_git_paths_no_git() {
        let tmp = TempDir::new();
        assert!(find_git_paths(tmp.path()).is_none());
    }

    #[test]
    fn test_find_git_paths_regular_repo() {
        let tmp = TempDir::new();
        let head = write_git_head(tmp.path(), "ref: refs/heads/main\n");

        let paths = find_git_paths(tmp.path()).expect("repository should be detected");
        assert_eq!(paths.head_path, head);
        assert_eq!(paths._repo_dir.as_path(), tmp.path());
    }

    #[test]
    fn test_find_git_paths_walk_up() {
        let tmp = TempDir::new();
        let head = write_git_head(tmp.path(), "ref: refs/heads/main\n");
        let deep = tmp.path().join("sub").join("deep");
        std::fs::create_dir_all(&deep).unwrap();

        let paths = find_git_paths(&deep).expect("repository above `deep` should be found");
        assert_eq!(paths.head_path, head);
        assert_eq!(paths._repo_dir.as_path(), tmp.path());
    }

    #[test]
    fn test_find_git_paths_gitfile_absolute() {
        let tmp = TempDir::new();
        let (worktree, real_git_dir) = make_linked_layout(tmp.path(), "linked");
        std::fs::write(
            worktree.join(".git"),
            format!("gitdir: {}\n", real_git_dir.display()),
        )
        .unwrap();

        let paths = find_git_paths(&worktree).expect("gitfile should be followed");
        assert_eq!(paths.head_path, real_git_dir.join("HEAD"));
        assert_eq!(paths._repo_dir, worktree);
        assert_eq!(resolve_git_branch(&worktree).as_deref(), Some("linked"));
    }

    #[test]
    fn test_resolve_git_branch_gitfile_relative() {
        let tmp = TempDir::new();
        let (worktree, _real_git_dir) = make_linked_layout(tmp.path(), "linked");
        std::fs::write(worktree.join(".git"), "gitdir: ../real-git-dir\n").unwrap();

        assert_eq!(resolve_git_branch(&worktree).as_deref(), Some("linked"));
    }

    #[test]
    fn test_resolve_git_branch_from_head() {
        let tmp = TempDir::new();
        write_git_head(tmp.path(), "ref: refs/heads/feature-branch\n");

        let result = resolve_git_branch(tmp.path());
        assert_eq!(result.as_deref(), Some("feature-branch"));
    }

    #[test]
    fn test_resolve_git_branch_detached() {
        let tmp = TempDir::new();
        write_git_head(tmp.path(), "abc123def456\n");

        let result = resolve_git_branch(tmp.path());
        assert_eq!(result.as_deref(), Some("detached"));
    }

    #[test]
    fn test_resolve_git_branch_no_git() {
        let tmp = TempDir::new();
        assert!(resolve_git_branch(tmp.path()).is_none());
    }
}
368
/// Paths located by `find_git_paths` for the repository enclosing a directory.
struct GitPaths {
    /// The `HEAD` file to read for the checked-out ref.
    head_path: PathBuf,
    /// Directory that contains the `.git` entry. Despite the leading
    /// underscore, `resolve_git_branch` reads it as the working directory for
    /// its `git` CLI fallback.
    _repo_dir: PathBuf,
}
373
374fn find_git_paths(cwd: &Path) -> Option<GitPaths> {
376 let mut dir = Some(cwd.to_path_buf());
377 while let Some(ref d) = dir {
378 let git_path = d.join(".git");
379 if git_path.exists() {
380 if git_path.is_file() {
381 let content = fs::read_to_string(&git_path).ok()?;
383 let content = content.trim();
384 if let Some(git_dir_str) = content.strip_prefix("gitdir: ") {
385 let git_dir = d.join(git_dir_str);
386 let head_path = git_dir.join("HEAD");
387 if head_path.exists() {
388 return Some(GitPaths {
389 _repo_dir: d.clone(),
390 head_path,
391 });
392 }
393 }
394 } else if git_path.is_dir() {
395 let head_path = git_path.join("HEAD");
397 if head_path.exists() {
398 return Some(GitPaths {
399 _repo_dir: d.clone(),
400 head_path,
401 });
402 }
403 }
404 }
405 dir = d.parent().map(|p| p.to_path_buf());
406 }
407 None
408}
409
410fn resolve_git_branch(cwd: &Path) -> Option<String> {
412 let paths = find_git_paths(cwd)?;
413 let content = fs::read_to_string(&paths.head_path).ok()?;
414 let content = content.trim();
415
416 if let Some(branch) = content.strip_prefix("ref: refs/heads/") {
417 if branch == ".invalid" {
418 resolve_branch_with_git(&paths._repo_dir)
420 } else {
421 Some(branch.to_string())
422 }
423 } else {
424 Some("detached".to_string())
426 }
427}
428
/// Asks the `git` CLI which branch `HEAD` points to, running it inside
/// `repo_dir`.
///
/// `resolve_git_branch` calls this when `HEAD` on disk reads
/// `ref: refs/heads/.invalid`. That is the placeholder git's reftable ref
/// backend writes (presumably the reason for this fallback), and the real
/// branch is not recorded in the file.
///
/// Returns `None` if running `git` fails (e.g. `git` is not installed or
/// `repo_dir` does not exist). If git runs but reports no symbolic ref, or
/// prints nothing, the result is `Some("detached")`.
fn resolve_branch_with_git(repo_dir: &Path) -> Option<String> {
    let mut command = std::process::Command::new("git");
    // `--no-optional-locks` keeps this read-only query from taking optional locks.
    command
        .current_dir(repo_dir)
        .args(["--no-optional-locks", "symbolic-ref", "--quiet", "--short", "HEAD"]);

    let output = match command.output() {
        Ok(output) => output,
        Err(_) => return None,
    };

    let branch = if output.status.success() {
        let stdout = String::from_utf8_lossy(&output.stdout);
        let name = stdout.trim();
        if name.is_empty() {
            None
        } else {
            Some(name.to_string())
        }
    } else {
        None
    };

    Some(branch.unwrap_or_else(|| "detached".to_string()))
}