polykit_core/
task_cache.rs1use std::fs;
4use std::path::{Path, PathBuf};
5
6use bincode;
7use serde::{Deserialize, Serialize};
8use xxhash_rust::xxh3::xxh3_64;
9
10use crate::error::{Error, Result};
11use crate::runner::TaskResult;
12
/// Version stamp written into every `TaskCacheEntry` by `TaskCache::put`.
/// `TaskCache::get` treats an entry carrying any other version as a cache miss.
const TASK_CACHE_VERSION: u32 = 1;
14
/// On-disk record for one cached task run.
///
/// `TaskCache::put` serializes it with bincode, compresses it, and writes one file
/// per entry. `TaskCache::get` reverses that.
///
/// bincode encodes struct fields in declaration order without names, so this field
/// list *is* the file format. Changing it changes what existing files decode to.
/// Note that `get` only reaches its `version` check if the stored bytes still
/// deserialize into this struct. An incompatible layout surfaces as a parse error
/// instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TaskCacheEntry {
    /// Format version. Entries whose version differs from `TASK_CACHE_VERSION` are misses.
    version: u32,
    /// Package the task belongs to. Compared with the requested package on read.
    package_name: String,
    /// Task name. Compared with the requested task on read.
    task_name: String,
    /// Exact command string that was run. Compared with the requested command on read.
    command: String,
    /// xxh3-64 hash of `command`, re-checked on read.
    command_hash: u64,
    /// Whether the run succeeded (`put` only stores successful runs).
    success: bool,
    /// Captured standard output of the run.
    stdout: String,
    /// Captured standard error of the run.
    stderr: String,
}
26
/// File-backed cache of successful task results.
///
/// Entries are keyed by package name, task name and the exact command string.
/// Each entry is stored as its own file under `cache_dir`. Clones refer to the
/// same directory.
#[derive(Debug, Clone)]
pub struct TaskCache {
    /// Directory holding one `.bin` file per cached entry.
    cache_dir: PathBuf,
}
32
33impl TaskCache {
34 pub fn new(cache_dir: impl AsRef<Path>) -> Self {
36 Self {
37 cache_dir: cache_dir.as_ref().to_path_buf(),
38 }
39 }
40
41 fn cache_key(package_name: &str, task_name: &str, command: &str) -> String {
43 let mut buffer = Vec::with_capacity(package_name.len() + task_name.len() + command.len());
44 buffer.extend_from_slice(package_name.as_bytes());
45 buffer.extend_from_slice(task_name.as_bytes());
46 buffer.extend_from_slice(command.as_bytes());
47 let hash = xxh3_64(&buffer);
48
49 let safe_package = package_name.replace(['/', '\\', '.', ':'], "_");
50 let safe_task = task_name.replace(['/', '\\', '.', ':'], "_");
51 format!("task_{}_{}_{:x}", safe_package, safe_task, hash)
52 }
53
54 fn get_safe_cache_path(&self, cache_key: &str) -> Result<PathBuf> {
55 let filename = format!("{}.bin", cache_key);
56 let cache_path = self.cache_dir.join(&filename);
57
58 if let Ok(canonical_cache_dir) = self.cache_dir.canonicalize() {
59 if let Ok(canonical_cache_path) = cache_path
60 .canonicalize()
61 .or_else(|_| self.cache_dir.canonicalize().map(|dir| dir.join(&filename)))
62 {
63 if !canonical_cache_path.starts_with(&canonical_cache_dir) {
64 return Err(Error::Adapter {
65 package: "task-cache".to_string(),
66 message: "Invalid cache path detected".to_string(),
67 });
68 }
69 return Ok(cache_path);
70 }
71 }
72
73 Ok(cache_path)
74 }
75
76 pub fn get(
78 &self,
79 package_name: &str,
80 task_name: &str,
81 command: &str,
82 ) -> Result<Option<TaskResult>> {
83 let cache_key = Self::cache_key(package_name, task_name, command);
84 let cache_path = self.get_safe_cache_path(&cache_key)?;
85
86 if !cache_path.exists() {
87 return Ok(None);
88 }
89
90 let compressed = fs::read(&cache_path).map_err(Error::Io)?;
91 let content = if compressed.len() < 1024 {
92 lz4_flex::decompress_size_prepended(&compressed).map_err(|e| Error::Adapter {
93 package: "task-cache".to_string(),
94 message: format!("Failed to decompress task cache (LZ4): {}", e),
95 })?
96 } else {
97 zstd::decode_all(&compressed[..]).map_err(|e| Error::Adapter {
98 package: "task-cache".to_string(),
99 message: format!("Failed to decompress task cache (zstd): {}", e),
100 })?
101 };
102
103 let entry: TaskCacheEntry = bincode::deserialize(&content).map_err(|e| Error::Adapter {
104 package: "task-cache".to_string(),
105 message: format!("Failed to parse task cache: {}", e),
106 })?;
107
108 if entry.version != TASK_CACHE_VERSION {
109 return Ok(None);
110 }
111
112 if entry.package_name != package_name
113 || entry.task_name != task_name
114 || entry.command != command
115 {
116 return Ok(None);
117 }
118
119 let command_hash = xxh3_64(command.as_bytes());
120 if command_hash != entry.command_hash {
121 return Ok(None);
122 }
123
124 Ok(Some(TaskResult {
125 package_name: entry.package_name,
126 task_name: entry.task_name,
127 success: entry.success,
128 stdout: entry.stdout,
129 stderr: entry.stderr,
130 }))
131 }
132
133 pub fn put(
135 &self,
136 package_name: &str,
137 task_name: &str,
138 command: &str,
139 result: &TaskResult,
140 ) -> Result<()> {
141 if !result.success {
142 return Ok(());
143 }
144
145 fs::create_dir_all(&self.cache_dir).map_err(Error::Io)?;
146
147 let cache_key = Self::cache_key(package_name, task_name, command);
148 let cache_path = self.get_safe_cache_path(&cache_key)?;
149
150 let command_hash = xxh3_64(command.as_bytes());
151
152 let entry = TaskCacheEntry {
153 version: TASK_CACHE_VERSION,
154 package_name: package_name.to_string(),
155 task_name: task_name.to_string(),
156 command: command.to_string(),
157 command_hash,
158 success: result.success,
159 stdout: result.stdout.clone(),
160 stderr: result.stderr.clone(),
161 };
162
163 let serialized = bincode::serialize(&entry).map_err(|e| Error::Adapter {
164 package: "task-cache".to_string(),
165 message: format!("Failed to serialize task cache: {}", e),
166 })?;
167
168 let compressed = if serialized.len() < 1024 {
169 lz4_flex::compress_prepend_size(&serialized)
170 } else {
171 zstd::encode_all(&serialized[..], 3).map_err(|e| Error::Adapter {
172 package: "task-cache".to_string(),
173 message: format!("Failed to compress task cache: {}", e),
174 })?
175 };
176
177 fs::write(&cache_path, compressed).map_err(Error::Io)?;
178
179 Ok(())
180 }
181
182 pub fn clear(&self) -> Result<()> {
184 if self.cache_dir.exists() {
185 fs::remove_dir_all(&self.cache_dir).map_err(Error::Io)?;
186 }
187 Ok(())
188 }
189}