jj_ryu/tracking/
pr_cache.rs1use super::storage::resolve_repo_path;
7use crate::error::{Error, Result};
8use crate::types::PullRequest;
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::fs;
12use std::path::{Path, PathBuf};
13
/// Schema version stamped into the cache by [`PrCache::new`] and on every save.
14pub const PR_CACHE_VERSION: u32 = 1;
16
/// File name of the PR cache; it lives in the `ryu` directory under the repo path.
17const PR_CACHE_FILE: &str = "pr_cache.toml";
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
22pub struct CachedPr {
23 pub bookmark: String,
25 pub number: u64,
27 pub url: String,
29 pub remote: String,
31 pub updated_at: DateTime<Utc>,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize, Default)]
37pub struct PrCache {
38 pub version: u32,
40 #[serde(default)]
42 pub prs: Vec<CachedPr>,
43}
44
45impl PrCache {
46 pub const fn new() -> Self {
48 Self {
49 version: PR_CACHE_VERSION,
50 prs: Vec::new(),
51 }
52 }
53
54 pub fn get(&self, bookmark: &str) -> Option<&CachedPr> {
56 self.prs.iter().find(|p| p.bookmark == bookmark)
57 }
58
59 pub fn upsert(&mut self, bookmark: &str, pr: &PullRequest, remote: &str) {
61 let entry = CachedPr {
62 bookmark: bookmark.to_string(),
63 number: pr.number,
64 url: pr.html_url.clone(),
65 remote: remote.to_string(),
66 updated_at: Utc::now(),
67 };
68
69 if let Some(existing) = self.prs.iter_mut().find(|p| p.bookmark == bookmark) {
70 *existing = entry;
71 } else {
72 self.prs.push(entry);
73 }
74 }
75
76 pub fn remove(&mut self, bookmark: &str) -> bool {
78 let len_before = self.prs.len();
79 self.prs.retain(|p| p.bookmark != bookmark);
80 self.prs.len() < len_before
81 }
82
83 pub fn retain_bookmarks(&mut self, bookmarks: &[&str]) {
85 self.prs
86 .retain(|p| bookmarks.contains(&p.bookmark.as_str()));
87 }
88}
89
90pub fn pr_cache_path(workspace_root: &Path) -> PathBuf {
92 resolve_repo_path(workspace_root)
93 .join("ryu")
94 .join(PR_CACHE_FILE)
95}
96
97pub fn load_pr_cache(workspace_root: &Path) -> Result<PrCache> {
101 let path = pr_cache_path(workspace_root);
102
103 if !path.exists() {
104 return Ok(PrCache::new());
105 }
106
107 let content = fs::read_to_string(&path)
108 .map_err(|e| Error::Tracking(format!("failed to read {}: {e}", path.display())))?;
109
110 let cache: PrCache = toml::from_str(&content)
111 .map_err(|e| Error::Tracking(format!("failed to parse {}: {e}", path.display())))?;
112
113 Ok(cache)
114}
115
116pub fn save_pr_cache(workspace_root: &Path, cache: &PrCache) -> Result<()> {
120 let path = pr_cache_path(workspace_root);
121 let dir = path.parent().expect("path has parent");
122
123 if !dir.exists() {
125 fs::create_dir_all(dir)
126 .map_err(|e| Error::Tracking(format!("failed to create {}: {e}", dir.display())))?;
127 }
128
129 let mut cache_to_save = cache.clone();
131 cache_to_save.version = PR_CACHE_VERSION;
132
133 let content = toml::to_string_pretty(&cache_to_save)
134 .map_err(|e| Error::Tracking(format!("failed to serialize PR cache: {e}")))?;
135
136 let content_with_header = format!(
138 "# PR association cache - regenerated from platform API on submit\n\
139 # Safe to delete; will be rebuilt on next submit\n\n{content}"
140 );
141
142 fs::write(&path, content_with_header)
143 .map_err(|e| Error::Tracking(format!("failed to write {}: {e}", path.display())))?;
144
145 Ok(())
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a temporary directory containing `.jj/repo`, mimicking the
    /// layout the path resolution expects from a jj workspace.
    fn setup_fake_jj_workspace() -> TempDir {
        let workspace = TempDir::new().unwrap();
        let repo_dir = workspace.path().join(".jj").join("repo");
        std::fs::create_dir_all(repo_dir).unwrap();
        workspace
    }

    /// Builds a minimal pull request whose URL embeds `number`.
    fn make_test_pr(number: u64) -> PullRequest {
        let html_url = format!("https://github.com/owner/repo/pull/{number}");
        PullRequest {
            number,
            html_url,
            base_ref: "main".to_string(),
            head_ref: "feat".to_string(),
            title: "Test PR".to_string(),
            node_id: None,
            is_draft: false,
        }
    }

    #[test]
    fn test_pr_cache_path() {
        let workspace = setup_fake_jj_workspace();
        let resolved = pr_cache_path(workspace.path());
        assert!(resolved.ends_with(".jj/repo/ryu/pr_cache.toml"));
    }

    #[test]
    fn test_load_missing_file_returns_empty() {
        let workspace = setup_fake_jj_workspace();
        let loaded = load_pr_cache(workspace.path()).unwrap();
        assert!(loaded.prs.is_empty());
        assert_eq!(loaded.version, PR_CACHE_VERSION);
    }

    #[test]
    fn test_upsert_and_get() {
        let mut cache = PrCache::new();

        // First insert creates the entry.
        let first = make_test_pr(123);
        cache.upsert("feat-auth", &first, "origin");

        let entry = cache.get("feat-auth").unwrap();
        assert_eq!(entry.number, 123);
        assert_eq!(entry.remote, "origin");
        assert!(entry.url.contains("123"));

        // Second insert for the same bookmark replaces it.
        let second = make_test_pr(456);
        cache.upsert("feat-auth", &second, "upstream");

        let entry = cache.get("feat-auth").unwrap();
        assert_eq!(entry.number, 456);
        assert_eq!(entry.remote, "upstream");
    }

    #[test]
    fn test_remove() {
        let mut cache = PrCache::new();
        for (name, number) in [("feat-auth", 123), ("feat-db", 124)] {
            cache.upsert(name, &make_test_pr(number), "origin");
        }

        assert!(cache.remove("feat-auth"));
        assert!(cache.get("feat-auth").is_none());
        assert!(cache.get("feat-db").is_some());

        // Removing a bookmark that is already gone reports `false`.
        assert!(!cache.remove("feat-auth"));
    }

    #[test]
    fn test_retain_bookmarks() {
        let mut cache = PrCache::new();
        for (name, number) in [("feat-auth", 123), ("feat-db", 124), ("feat-ui", 125)] {
            cache.upsert(name, &make_test_pr(number), "origin");
        }

        cache.retain_bookmarks(&["feat-auth", "feat-ui"]);

        assert!(cache.get("feat-auth").is_some());
        assert!(cache.get("feat-db").is_none());
        assert!(cache.get("feat-ui").is_some());
    }

    #[test]
    fn test_roundtrip_serialization() {
        let workspace = setup_fake_jj_workspace();
        let root = workspace.path();

        let mut original = PrCache::new();
        original.upsert("feat-auth", &make_test_pr(123), "origin");
        original.upsert("feat-db", &make_test_pr(124), "upstream");

        save_pr_cache(root, &original).unwrap();

        let reloaded = load_pr_cache(root).unwrap();
        assert_eq!(reloaded.prs.len(), 2);

        let auth_entry = reloaded.get("feat-auth").unwrap();
        assert_eq!(auth_entry.number, 123);
        assert_eq!(auth_entry.remote, "origin");

        let db_entry = reloaded.get("feat-db").unwrap();
        assert_eq!(db_entry.number, 124);
        assert_eq!(db_entry.remote, "upstream");
    }

    #[test]
    fn test_file_contains_header_comment() {
        let workspace = setup_fake_jj_workspace();
        let root = workspace.path();
        save_pr_cache(root, &PrCache::new()).unwrap();

        let on_disk = fs::read_to_string(pr_cache_path(root)).unwrap();
        assert!(on_disk.contains("PR association cache"));
        assert!(on_disk.contains("Safe to delete"));
    }
}