rattler_networking/authentication_storage/backends/
file.rs1use std::{
3 collections::BTreeMap,
4 ffi::OsStr,
5 io::BufWriter,
6 path::{Path, PathBuf},
7 sync::{Arc, RwLock},
8};
9
10use crate::{
11 authentication_storage::{AuthenticationStorageError, StorageBackend},
12 Authentication,
13};
14
15#[derive(Clone, Debug)]
16struct FileStorageCache {
17 content: BTreeMap<String, Authentication>,
18}
19
20#[derive(Clone, Debug)]
23pub struct FileStorage {
24 pub path: PathBuf,
26
27 cache: Arc<RwLock<FileStorageCache>>,
31}
32
33#[derive(thiserror::Error, Debug)]
35pub enum FileStorageError {
36 #[error(transparent)]
38 IOError(#[from] std::io::Error),
39
40 #[error("failed to parse {0}: {1}")]
42 JSONError(PathBuf, serde_json::Error),
43}
44
45impl FileStorageCache {
46 pub fn from_path(path: &Path) -> Result<Self, FileStorageError> {
47 match fs_err::read_to_string(path) {
48 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(Self {
49 content: BTreeMap::new(),
50 }),
51 Err(e) => Err(FileStorageError::IOError(e)),
52 Ok(content) => {
53 let content = serde_json::from_str(&content)
54 .map_err(|e| FileStorageError::JSONError(path.to_path_buf(), e))?;
55 Ok(Self { content })
56 }
57 }
58 }
59}
60
61impl FileStorage {
62 pub fn from_path(path: PathBuf) -> Result<Self, FileStorageError> {
64 let cache = Arc::new(RwLock::new(FileStorageCache::from_path(&path)?));
66
67 Ok(Self { path, cache })
68 }
69
70 #[cfg(feature = "dirs")]
72 pub fn new() -> Result<Self, FileStorageError> {
73 let path = dirs::home_dir()
74 .unwrap()
75 .join(".rattler")
76 .join("credentials.json");
77 Self::from_path(path)
78 }
79
80 fn read_json(&self) -> Result<BTreeMap<String, Authentication>, FileStorageError> {
83 let new_cache = FileStorageCache::from_path(&self.path)?;
84 let mut cache = self.cache.write().unwrap();
85 cache.content = new_cache.content;
86 Ok(cache.content.clone())
87 }
88
89 fn write_json(&self, dict: &BTreeMap<String, Authentication>) -> Result<(), FileStorageError> {
91 let parent = self
92 .path
93 .parent()
94 .ok_or(FileStorageError::IOError(std::io::Error::new(
95 std::io::ErrorKind::NotFound,
96 "Parent directory not found",
97 )))?;
98 std::fs::create_dir_all(parent)?;
99
100 let prefix = self
101 .path
102 .file_stem()
103 .unwrap_or_else(|| OsStr::new("credentials"));
104 let extension = self
105 .path
106 .extension()
107 .and_then(OsStr::to_str)
108 .unwrap_or("json");
109
110 let mut temp_file = tempfile::Builder::new()
113 .prefix(prefix)
114 .suffix(&format!(".{extension}"))
115 .tempfile_in(parent)?;
116 serde_json::to_writer(BufWriter::new(&mut temp_file), dict)
117 .map_err(std::io::Error::from)?;
118 temp_file
119 .persist(&self.path)
120 .map_err(std::io::Error::from)?;
121
122 let mut cache = self.cache.write().unwrap();
124 cache.content = dict.clone();
125
126 Ok(())
127 }
128}
129
130impl StorageBackend for FileStorage {
131 fn store(
132 &self,
133 host: &str,
134 authentication: &crate::Authentication,
135 ) -> Result<(), AuthenticationStorageError> {
136 let mut dict = self.read_json()?;
137 dict.insert(host.to_string(), authentication.clone());
138 Ok(self.write_json(&dict)?)
139 }
140
141 fn get(&self, host: &str) -> Result<Option<crate::Authentication>, AuthenticationStorageError> {
142 let cache = self.cache.read().unwrap();
143 Ok(cache.content.get(host).cloned())
144 }
145
146 fn delete(&self, host: &str) -> Result<(), AuthenticationStorageError> {
147 let mut dict = self.read_json()?;
148 if dict.remove(host).is_some() {
149 Ok(self.write_json(&dict)?)
150 } else {
151 Ok(())
152 }
153 }
154}
155
#[cfg(test)]
mod tests {
    use std::fs;

    use insta::assert_snapshot;
    use tempfile::tempdir;

    use super::*;

    /// Exercises store, get, and delete against a file in a temporary
    /// directory, snapshots the written JSON, and checks that a file with
    /// invalid JSON is rejected on load.
    #[test]
    fn test_file_storage() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.json");
        let storage = FileStorage::from_path(path.clone()).unwrap();

        // The file does not exist yet, so nothing is stored.
        assert_eq!(storage.get("test").unwrap(), None);

        let conda_token = Authentication::CondaToken("password".to_string());
        storage.store("test", &conda_token).unwrap();
        assert_eq!(storage.get("test").unwrap(), Some(conda_token));

        let bearer_token = Authentication::BearerToken("password".to_string());
        storage.store("bearer", &bearer_token).unwrap();

        let basic_http = Authentication::BasicHTTP {
            username: "user".to_string(),
            password: "password".to_string(),
        };
        storage.store("basic", &basic_http).unwrap();

        assert_snapshot!(fs::read_to_string(&path).unwrap());

        storage.delete("test").unwrap();
        assert_eq!(storage.get("test").unwrap(), None);

        // Overwrite the file with something that is not JSON; loading it must fail.
        fs::write(&path, b"invalid json").unwrap();
        assert!(FileStorage::from_path(path.clone()).is_err());
    }
}