1pub mod sqlite;
7
8pub use sqlite::SqliteStorage;
9
10use crate::error::Result;
11use crate::types::{DownloadId, DownloadStatus};
12use async_trait::async_trait;
13
/// Lifecycle state of a single [`Segment`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SegmentState {
    /// Initial state; this is what [`Segment::new`] assigns.
    Pending,
    /// Presumably set while a transfer for the segment is in flight (set outside this file).
    Downloading,
    /// The only state for which [`Segment::is_complete`] reports `true`.
    Completed,
    /// The segment failed; carries the error text and a retry counter.
    ///
    /// NOTE(review): whether `retries` counts attempts made or attempts left is decided by
    /// the code that constructs this variant — confirm there.
    Failed {
        /// Human-readable description of the failure.
        error: String,
        /// Retry counter recorded alongside the failure.
        retries: u32,
    },
}
26
27#[derive(Debug, Clone)]
29pub struct Segment {
30 pub index: usize,
32 pub start: u64,
34 pub end: u64,
36 pub downloaded: u64,
38 pub state: SegmentState,
40}
41
42impl Segment {
43 pub fn new(index: usize, start: u64, end: u64) -> Self {
45 Self {
46 index,
47 start,
48 end,
49 downloaded: 0,
50 state: SegmentState::Pending,
51 }
52 }
53
54 pub fn size(&self) -> u64 {
56 self.end - self.start + 1
57 }
58
59 pub fn is_complete(&self) -> bool {
61 self.state == SegmentState::Completed
62 }
63
64 pub fn remaining(&self) -> u64 {
66 self.size().saturating_sub(self.downloaded)
67 }
68}
69
/// Persistence backend for download records and their per-download segment lists.
///
/// Implementors must be `Send + Sync` so a single instance can be shared between tasks;
/// the async methods are provided through `#[async_trait]`. This module contains the
/// in-memory implementation [`MemoryStorage`] and re-exports `SqliteStorage` from the
/// `sqlite` submodule.
#[async_trait]
pub trait Storage: Send + Sync {
    /// Persists `status`, keyed by its id.
    ///
    /// NOTE(review): `MemoryStorage` overwrites any existing record with the same id
    /// (upsert); confirm other backends behave the same way.
    async fn save_download(&self, status: &DownloadStatus) -> Result<()>;

    /// Loads the record for `id`.
    ///
    /// `Ok(None)` presumably means no record is stored for `id` (that is what
    /// `MemoryStorage` returns in that case).
    async fn load_download(&self, id: DownloadId) -> Result<Option<DownloadStatus>>;

    /// Returns every stored download record.
    ///
    /// No ordering is promised by this signature; `MemoryStorage` returns records in
    /// `HashMap` iteration order.
    async fn load_all(&self) -> Result<Vec<DownloadStatus>>;

    /// Removes the record for `id`.
    ///
    /// `MemoryStorage` also drops any segments stored for `id` and succeeds even if
    /// nothing was stored. NOTE(review): confirm whether other backends cascade to
    /// segments too, or whether callers must also call `delete_segments`.
    async fn delete_download(&self, id: DownloadId) -> Result<()>;

    /// Stores the segment list for download `id`.
    ///
    /// `MemoryStorage` replaces any previously stored list for `id` wholesale.
    async fn save_segments(&self, id: DownloadId, segments: &[Segment]) -> Result<()>;

    /// Loads the segment list for download `id`.
    ///
    /// `MemoryStorage` returns an empty `Vec` when nothing is stored for `id`.
    async fn load_segments(&self, id: DownloadId) -> Result<Vec<Segment>>;

    /// Removes the segment list for download `id`, leaving the download record in place.
    async fn delete_segments(&self, id: DownloadId) -> Result<()>;

    /// Reports whether the backend is usable; `MemoryStorage` always returns `Ok(())`.
    async fn health_check(&self) -> Result<()>;

    /// Backend maintenance hook; a no-op in `MemoryStorage`.
    ///
    /// NOTE(review): presumably reclaims space in persistent backends — confirm against
    /// the `sqlite` implementation.
    async fn compact(&self) -> Result<()>;
}
103
104#[derive(Debug, Default)]
106pub struct MemoryStorage {
107 downloads: parking_lot::RwLock<std::collections::HashMap<DownloadId, DownloadStatus>>,
108 segments: parking_lot::RwLock<std::collections::HashMap<DownloadId, Vec<Segment>>>,
109}
110
111impl MemoryStorage {
112 pub fn new() -> Self {
113 Self::default()
114 }
115}
116
117#[async_trait]
118impl Storage for MemoryStorage {
119 async fn save_download(&self, status: &DownloadStatus) -> Result<()> {
120 self.downloads.write().insert(status.id, status.clone());
121 Ok(())
122 }
123
124 async fn load_download(&self, id: DownloadId) -> Result<Option<DownloadStatus>> {
125 Ok(self.downloads.read().get(&id).cloned())
126 }
127
128 async fn load_all(&self) -> Result<Vec<DownloadStatus>> {
129 Ok(self.downloads.read().values().cloned().collect())
130 }
131
132 async fn delete_download(&self, id: DownloadId) -> Result<()> {
133 self.downloads.write().remove(&id);
134 self.segments.write().remove(&id);
135 Ok(())
136 }
137
138 async fn save_segments(&self, id: DownloadId, segments: &[Segment]) -> Result<()> {
139 self.segments.write().insert(id, segments.to_vec());
140 Ok(())
141 }
142
143 async fn load_segments(&self, id: DownloadId) -> Result<Vec<Segment>> {
144 Ok(self.segments.read().get(&id).cloned().unwrap_or_default())
145 }
146
147 async fn delete_segments(&self, id: DownloadId) -> Result<()> {
148 self.segments.write().remove(&id);
149 Ok(())
150 }
151
152 async fn health_check(&self) -> Result<()> {
153 Ok(())
154 }
155
156 async fn compact(&self) -> Result<()> {
157 Ok(())
158 }
159}
160
#[cfg(test)]
mod tests {
    // Unit tests for `Segment` arithmetic and the in-memory `Storage` backend.
    use super::*;
    use crate::types::{DownloadKind, DownloadMetadata, DownloadProgress, DownloadState};
    use chrono::Utc;
    use std::path::PathBuf;

    /// Builds a minimal HTTP download status with a fresh id for storage tests.
    fn create_test_status() -> DownloadStatus {
        DownloadStatus {
            id: DownloadId::new(),
            kind: DownloadKind::Http,
            state: DownloadState::Downloading,
            priority: crate::priority_queue::DownloadPriority::Normal,
            progress: DownloadProgress::default(),
            metadata: DownloadMetadata {
                name: "test.zip".to_string(),
                url: Some("https://example.com/test.zip".to_string()),
                magnet_uri: None,
                info_hash: None,
                save_dir: PathBuf::from("/tmp"),
                filename: Some("test.zip".to_string()),
                user_agent: None,
                referer: None,
                headers: Vec::new(),
                cookies: Vec::new(),
                checksum: None,
                mirrors: Vec::new(),
                etag: None,
                last_modified: None,
            },
            torrent_info: None,
            peers: None,
            created_at: Utc::now(),
            completed_at: None,
        }
    }

    #[tokio::test]
    async fn test_memory_storage() {
        let storage = MemoryStorage::new();
        let status = create_test_status();
        let id = status.id;

        storage.save_download(&status).await.unwrap();

        let loaded = storage
            .load_download(id)
            .await
            .unwrap()
            .expect("saved download should be loadable");
        assert_eq!(loaded.id, id);

        let all = storage.load_all().await.unwrap();
        assert_eq!(all.len(), 1);

        storage.delete_download(id).await.unwrap();
        let loaded = storage.load_download(id).await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn test_save_download_overwrites_existing_entry() {
        let storage = MemoryStorage::new();
        let mut status = create_test_status();
        let id = status.id;

        storage.save_download(&status).await.unwrap();
        status.metadata.name = "renamed.zip".to_string();
        storage.save_download(&status).await.unwrap();

        // Saving the same id twice must replace the record, not duplicate it.
        let all = storage.load_all().await.unwrap();
        assert_eq!(all.len(), 1);

        let loaded = storage
            .load_download(id)
            .await
            .unwrap()
            .expect("download should still be stored");
        assert_eq!(loaded.metadata.name, "renamed.zip");
    }

    #[tokio::test]
    async fn test_delete_download_removes_segments() {
        let storage = MemoryStorage::new();
        let status = create_test_status();
        let id = status.id;

        storage.save_download(&status).await.unwrap();
        storage
            .save_segments(id, &[Segment::new(0, 0, 99)])
            .await
            .unwrap();

        storage.delete_download(id).await.unwrap();
        assert!(storage.load_segments(id).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_segment_storage() {
        let storage = MemoryStorage::new();
        let id = DownloadId::new();

        let segments = vec![
            Segment::new(0, 0, 999),
            Segment::new(1, 1000, 1999),
            Segment::new(2, 2000, 2999),
        ];

        storage.save_segments(id, &segments).await.unwrap();

        let loaded = storage.load_segments(id).await.unwrap();
        assert_eq!(loaded.len(), 3);
        assert_eq!(loaded[0].start, 0);
        assert_eq!(loaded[1].start, 1000);
        assert_eq!(loaded[2].start, 2000);

        storage.delete_segments(id).await.unwrap();
        let loaded = storage.load_segments(id).await.unwrap();
        assert!(loaded.is_empty());
    }

    #[tokio::test]
    async fn test_load_segments_unknown_id_is_empty() {
        let storage = MemoryStorage::new();
        let loaded = storage.load_segments(DownloadId::new()).await.unwrap();
        assert!(loaded.is_empty());
    }

    #[test]
    fn test_segment_size() {
        let segment = Segment::new(0, 0, 999);
        assert_eq!(segment.size(), 1000);

        let segment = Segment::new(1, 1000, 1999);
        assert_eq!(segment.size(), 1000);
    }

    #[test]
    fn test_segment_progress_and_completion() {
        let mut segment = Segment::new(0, 0, 999);
        assert_eq!(segment.state, SegmentState::Pending);
        assert!(!segment.is_complete());
        assert_eq!(segment.remaining(), 1000);

        segment.downloaded = 400;
        assert_eq!(segment.remaining(), 600);

        // Over-reported progress must clamp to zero rather than underflow.
        segment.downloaded = 2000;
        assert_eq!(segment.remaining(), 0);

        segment.state = SegmentState::Failed {
            error: "timeout".to_string(),
            retries: 3,
        };
        assert!(!segment.is_complete());

        segment.state = SegmentState::Completed;
        assert!(segment.is_complete());
    }
}