gosh_dl/storage/
mod.rs

//! Storage module.
//!
//! Persists download state and session data so downloads can resume after a
//! crash or restart. The primary backend, [`SqliteStorage`], uses SQLite in
//! WAL mode for crash-safe atomic commits.
5
6pub mod sqlite;
7
8pub use sqlite::SqliteStorage;
9
10use crate::error::Result;
11use crate::types::{DownloadId, DownloadStatus};
12use async_trait::async_trait;
13
/// Lifecycle stage of one byte-range segment of a multi-connection HTTP download.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SegmentState {
    /// Not started yet; the segment is waiting to be downloaded.
    Pending,
    /// A transfer for this segment is in progress.
    Downloading,
    /// The segment finished successfully.
    Completed,
    /// The segment failed; it may be retried.
    Failed {
        /// Description of the failure.
        error: String,
        /// Retry count recorded alongside the failure.
        retries: u32,
    },
}
26
27/// Represents a download segment for multi-connection HTTP downloads
28#[derive(Debug, Clone)]
29pub struct Segment {
30    /// Segment index (0-based)
31    pub index: usize,
32    /// Start byte offset (inclusive)
33    pub start: u64,
34    /// End byte offset (inclusive)
35    pub end: u64,
36    /// Bytes downloaded for this segment
37    pub downloaded: u64,
38    /// Current state
39    pub state: SegmentState,
40}
41
42impl Segment {
43    /// Create a new pending segment
44    pub fn new(index: usize, start: u64, end: u64) -> Self {
45        Self {
46            index,
47            start,
48            end,
49            downloaded: 0,
50            state: SegmentState::Pending,
51        }
52    }
53
54    /// Get the total size of this segment
55    pub fn size(&self) -> u64 {
56        self.end - self.start + 1
57    }
58
59    /// Check if segment is complete
60    pub fn is_complete(&self) -> bool {
61        self.state == SegmentState::Completed
62    }
63
64    /// Get remaining bytes to download
65    pub fn remaining(&self) -> u64 {
66        self.size().saturating_sub(self.downloaded)
67    }
68}
69
70/// Storage trait for persisting download state
71///
72/// Implementations of this trait handle storing and retrieving download
73/// state to allow resume after crashes or restarts.
74#[async_trait]
75pub trait Storage: Send + Sync {
76    /// Save or update a download's status
77    async fn save_download(&self, status: &DownloadStatus) -> Result<()>;
78
79    /// Load a download by ID
80    async fn load_download(&self, id: DownloadId) -> Result<Option<DownloadStatus>>;
81
82    /// Load all downloads
83    async fn load_all(&self) -> Result<Vec<DownloadStatus>>;
84
85    /// Delete a download record
86    async fn delete_download(&self, id: DownloadId) -> Result<()>;
87
88    /// Save segment state for an HTTP download
89    async fn save_segments(&self, id: DownloadId, segments: &[Segment]) -> Result<()>;
90
91    /// Load segment state for an HTTP download
92    async fn load_segments(&self, id: DownloadId) -> Result<Vec<Segment>>;
93
94    /// Delete segment state for a download
95    async fn delete_segments(&self, id: DownloadId) -> Result<()>;
96
97    /// Check if database is healthy
98    async fn health_check(&self) -> Result<()>;
99
100    /// Compact/vacuum the database
101    async fn compact(&self) -> Result<()>;
102}
103
104/// In-memory storage for testing
105#[derive(Debug, Default)]
106pub struct MemoryStorage {
107    downloads: parking_lot::RwLock<std::collections::HashMap<DownloadId, DownloadStatus>>,
108    segments: parking_lot::RwLock<std::collections::HashMap<DownloadId, Vec<Segment>>>,
109}
110
111impl MemoryStorage {
112    pub fn new() -> Self {
113        Self::default()
114    }
115}
116
117#[async_trait]
118impl Storage for MemoryStorage {
119    async fn save_download(&self, status: &DownloadStatus) -> Result<()> {
120        self.downloads.write().insert(status.id, status.clone());
121        Ok(())
122    }
123
124    async fn load_download(&self, id: DownloadId) -> Result<Option<DownloadStatus>> {
125        Ok(self.downloads.read().get(&id).cloned())
126    }
127
128    async fn load_all(&self) -> Result<Vec<DownloadStatus>> {
129        Ok(self.downloads.read().values().cloned().collect())
130    }
131
132    async fn delete_download(&self, id: DownloadId) -> Result<()> {
133        self.downloads.write().remove(&id);
134        self.segments.write().remove(&id);
135        Ok(())
136    }
137
138    async fn save_segments(&self, id: DownloadId, segments: &[Segment]) -> Result<()> {
139        self.segments.write().insert(id, segments.to_vec());
140        Ok(())
141    }
142
143    async fn load_segments(&self, id: DownloadId) -> Result<Vec<Segment>> {
144        Ok(self.segments.read().get(&id).cloned().unwrap_or_default())
145    }
146
147    async fn delete_segments(&self, id: DownloadId) -> Result<()> {
148        self.segments.write().remove(&id);
149        Ok(())
150    }
151
152    async fn health_check(&self) -> Result<()> {
153        Ok(())
154    }
155
156    async fn compact(&self) -> Result<()> {
157        Ok(())
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{DownloadKind, DownloadMetadata, DownloadProgress, DownloadState};
    use chrono::Utc;
    use std::path::PathBuf;

    /// Builds a minimal HTTP download record with a freshly generated ID.
    fn create_test_status() -> DownloadStatus {
        DownloadStatus {
            id: DownloadId::new(),
            kind: DownloadKind::Http,
            state: DownloadState::Downloading,
            priority: crate::priority_queue::DownloadPriority::Normal,
            progress: DownloadProgress::default(),
            metadata: DownloadMetadata {
                name: "test.zip".to_string(),
                url: Some("https://example.com/test.zip".to_string()),
                magnet_uri: None,
                info_hash: None,
                save_dir: PathBuf::from("/tmp"),
                filename: Some("test.zip".to_string()),
                user_agent: None,
                referer: None,
                headers: Vec::new(),
                cookies: Vec::new(),
                checksum: None,
                mirrors: Vec::new(),
                etag: None,
                last_modified: None,
            },
            torrent_info: None,
            peers: None,
            created_at: Utc::now(),
            completed_at: None,
        }
    }

    #[tokio::test]
    async fn test_memory_storage() {
        let storage = MemoryStorage::new();
        let status = create_test_status();
        let id = status.id;

        // Save
        storage.save_download(&status).await.unwrap();

        // Load
        let loaded = storage
            .load_download(id)
            .await
            .unwrap()
            .expect("saved download should be loadable");
        assert_eq!(loaded.id, id);

        // Load all
        let all = storage.load_all().await.unwrap();
        assert_eq!(all.len(), 1);

        // Delete
        storage.delete_download(id).await.unwrap();
        assert!(storage.load_download(id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_save_download_overwrites_existing_record() {
        let storage = MemoryStorage::new();
        let mut status = create_test_status();
        let id = status.id;

        storage.save_download(&status).await.unwrap();
        status.metadata.name = "renamed.zip".to_string();
        storage.save_download(&status).await.unwrap();

        // Saving the same ID twice must update in place, not duplicate.
        assert_eq!(storage.load_all().await.unwrap().len(), 1);
        let loaded = storage
            .load_download(id)
            .await
            .unwrap()
            .expect("record should still exist");
        assert_eq!(loaded.metadata.name, "renamed.zip");
    }

    #[tokio::test]
    async fn test_delete_download_removes_segments() {
        let storage = MemoryStorage::new();
        let status = create_test_status();
        let id = status.id;

        storage.save_download(&status).await.unwrap();
        storage
            .save_segments(id, &[Segment::new(0, 0, 99)])
            .await
            .unwrap();

        storage.delete_download(id).await.unwrap();
        assert!(storage.load_segments(id).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_unknown_id_yields_nothing() {
        let storage = MemoryStorage::new();
        let id = DownloadId::new();

        assert!(storage.load_download(id).await.unwrap().is_none());
        assert!(storage.load_segments(id).await.unwrap().is_empty());
        assert!(storage.load_all().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_segment_storage() {
        let storage = MemoryStorage::new();
        let id = DownloadId::new();

        let segments = vec![
            Segment::new(0, 0, 999),
            Segment::new(1, 1000, 1999),
            Segment::new(2, 2000, 2999),
        ];

        // Save segments
        storage.save_segments(id, &segments).await.unwrap();

        // Load segments
        let loaded = storage.load_segments(id).await.unwrap();
        assert_eq!(loaded.len(), 3);
        assert_eq!(loaded[0].start, 0);
        assert_eq!(loaded[1].start, 1000);
        assert_eq!(loaded[2].start, 2000);
        assert_eq!(loaded[2].end, 2999);
        assert!(loaded.iter().all(|s| s.state == SegmentState::Pending));

        // Delete segments
        storage.delete_segments(id).await.unwrap();
        let loaded = storage.load_segments(id).await.unwrap();
        assert!(loaded.is_empty());
    }

    #[test]
    fn test_segment_size() {
        let segment = Segment::new(0, 0, 999);
        assert_eq!(segment.size(), 1000);

        let segment = Segment::new(1, 1000, 1999);
        assert_eq!(segment.size(), 1000);
    }

    #[test]
    fn test_segment_progress() {
        let mut segment = Segment::new(0, 0, 999);
        assert!(!segment.is_complete());
        assert_eq!(segment.remaining(), 1000);

        segment.downloaded = 400;
        assert_eq!(segment.remaining(), 600);

        segment.state = SegmentState::Completed;
        assert!(segment.is_complete());
    }
}
257}