1use crate::error::{SpatialError, SpatialResult};
2use std::path::{Path, PathBuf};
3use tokio::io::AsyncWriteExt;
4
5pub fn get_checkpoint_dir() -> SpatialResult<PathBuf> {
6 if let Ok(custom_dir) = std::env::var("SPATIAL_MAKER_CHECKPOINTS") {
7 Ok(PathBuf::from(custom_dir))
8 } else {
9 let home = dirs::home_dir().ok_or_else(|| {
10 SpatialError::ConfigError("Could not determine home directory".to_string())
11 })?;
12 Ok(home.join(".spatial-maker").join("checkpoints"))
13 }
14}
15
/// Describes one downloadable depth model.
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    /// Model identifier, e.g. `depth-anything-v2-small`; used in log messages.
    pub name: String,
    /// File (or package) name of the model inside the checkpoint directory.
    pub filename: String,
    /// URL the model is downloaded from.
    pub url: String,
    /// Nominal size in megabytes. Used as the progress total when the server
    /// does not report a content length.
    pub size_mb: u32,
}
23
24impl ModelMetadata {
25 pub fn coreml(encoder_size: &str) -> SpatialResult<Self> {
26 match encoder_size {
27 "s" | "small" => Ok(ModelMetadata {
28 name: "depth-anything-v2-small".to_string(),
29 filename: "DepthAnythingV2SmallF16.mlpackage".to_string(),
30 url: "https://huggingface.co/mrgnw/depth-anything-v2-coreml/resolve/main/DepthAnythingV2SmallF16.mlpackage.tar.gz".to_string(),
31 size_mb: 48,
32 }),
33 "b" | "base" => Ok(ModelMetadata {
34 name: "depth-anything-v2-base".to_string(),
35 filename: "DepthAnythingV2BaseF16.mlpackage".to_string(),
36 url: "https://huggingface.co/mrgnw/depth-anything-v2-coreml/resolve/main/DepthAnythingV2BaseF16.mlpackage.tar.gz".to_string(),
37 size_mb: 186,
38 }),
39 "l" | "large" => Ok(ModelMetadata {
40 name: "depth-anything-v2-large".to_string(),
41 filename: "DepthAnythingV2LargeF16.mlpackage".to_string(),
42 url: "https://huggingface.co/mrgnw/depth-anything-v2-coreml/resolve/main/DepthAnythingV2LargeF16.mlpackage.tar.gz".to_string(),
43 size_mb: 638,
44 }),
45 other => Err(SpatialError::ConfigError(
46 format!("Unknown encoder size: '{}'. Use 's', 'b', or 'l'", other)
47 )),
48 }
49 }
50
51 #[cfg(feature = "onnx")]
52 pub fn onnx(encoder_size: &str) -> SpatialResult<Self> {
53 match encoder_size {
54 "s" | "small" => Ok(ModelMetadata {
55 name: "depth-anything-v2-small".to_string(),
56 filename: "depth_anything_v2_small.onnx".to_string(),
57 url: "https://huggingface.co/onnx-community/depth-anything-v2-small/resolve/main/onnx/model.onnx".to_string(),
58 size_mb: 99,
59 }),
60 "b" | "base" => Ok(ModelMetadata {
61 name: "depth-anything-v2-base".to_string(),
62 filename: "depth_anything_v2_base.onnx".to_string(),
63 url: "https://huggingface.co/onnx-community/depth-anything-v2-base/resolve/main/onnx/model.onnx".to_string(),
64 size_mb: 380,
65 }),
66 "l" | "large" => Ok(ModelMetadata {
67 name: "depth-anything-v2-large".to_string(),
68 filename: "depth_anything_v2_large.onnx".to_string(),
69 url: "https://huggingface.co/onnx-community/depth-anything-v2-large/resolve/main/onnx/model.onnx".to_string(),
70 size_mb: 1300,
71 }),
72 other => Err(SpatialError::ConfigError(
73 format!("Unknown encoder size: '{}'. Use 's', 'b', or 'l'", other)
74 )),
75 }
76 }
77}
78
79pub fn find_model(encoder_size: &str) -> SpatialResult<PathBuf> {
80 let checkpoint_dir = get_checkpoint_dir()?;
81
82 #[cfg(all(target_os = "macos", feature = "coreml"))]
83 {
84 let meta = ModelMetadata::coreml(encoder_size)?;
85 let model_path = checkpoint_dir.join(&meta.filename);
86 if model_path.exists() {
87 return Ok(model_path);
88 }
89 }
90
91 #[cfg(feature = "onnx")]
92 {
93 let meta = ModelMetadata::onnx(encoder_size)?;
94 let model_path = checkpoint_dir.join(&meta.filename);
95 if model_path.exists() {
96 return Ok(model_path);
97 }
98 }
99
100 let dev_paths = [
102 PathBuf::from("checkpoints"),
103 dirs::home_dir()
104 .unwrap_or_default()
105 .join(".spatial-maker")
106 .join("checkpoints"),
107 ];
108
109 for dir in &dev_paths {
110 if dir.exists() {
111 if let Ok(entries) = std::fs::read_dir(dir) {
112 for entry in entries.flatten() {
113 let name = entry.file_name().to_string_lossy().to_string();
114 if name.contains("DepthAnything") || name.contains("depth_anything") {
115 let lower_size = encoder_size.to_lowercase();
116 let name_lower = name.to_lowercase();
117 let matches = match lower_size.as_str() {
118 "s" | "small" => name_lower.contains("small"),
119 "b" | "base" => name_lower.contains("base"),
120 "l" | "large" => name_lower.contains("large"),
121 _ => false,
122 };
123 if matches {
124 return Ok(entry.path());
125 }
126 }
127 }
128 }
129 }
130 }
131
132 Err(SpatialError::ModelError(format!(
133 "Model not found for encoder size '{}'. Run download first.",
134 encoder_size
135 )))
136}
137
138pub fn model_exists(encoder_size: &str) -> bool {
139 find_model(encoder_size).is_ok()
140}
141
142pub async fn ensure_model_exists<F>(
143 encoder_size: &str,
144 progress_fn: Option<F>,
145) -> SpatialResult<PathBuf>
146where
147 F: FnMut(u64, u64),
148{
149 if let Ok(path) = find_model(encoder_size) {
150 return Ok(path);
151 }
152
153 let checkpoint_dir = get_checkpoint_dir()?;
154 tokio::fs::create_dir_all(&checkpoint_dir)
155 .await
156 .map_err(|e| {
157 SpatialError::IoError(format!("Failed to create checkpoint directory: {}", e))
158 })?;
159
160 #[cfg(all(target_os = "macos", feature = "coreml"))]
161 {
162 let meta = ModelMetadata::coreml(encoder_size)?;
163 let model_path = checkpoint_dir.join(&meta.filename);
164 download_model(&meta, &model_path, progress_fn).await?;
165 return Ok(model_path);
166 }
167
168 #[cfg(all(feature = "onnx", not(all(target_os = "macos", feature = "coreml"))))]
169 {
170 let meta = ModelMetadata::onnx(encoder_size)?;
171 let model_path = checkpoint_dir.join(&meta.filename);
172 download_model(&meta, &model_path, progress_fn).await?;
173 return Ok(model_path);
174 }
175
176 #[cfg(not(any(all(target_os = "macos", feature = "coreml"), feature = "onnx")))]
177 {
178 let _ = progress_fn;
179 Err(SpatialError::ConfigError(
180 "No depth backend enabled. Enable 'coreml' (macOS) or 'onnx' feature.".to_string(),
181 ))
182 }
183}
184
185async fn download_model<F>(
186 metadata: &ModelMetadata,
187 destination: &Path,
188 mut progress_fn: Option<F>,
189) -> SpatialResult<()>
190where
191 F: FnMut(u64, u64),
192{
193 tracing::info!("Downloading model: {} from {}", metadata.name, metadata.url);
194
195 let response = reqwest::get(&metadata.url)
196 .await
197 .map_err(|e| SpatialError::Other(format!("Failed to download model: {}", e)))?;
198
199 let total_bytes = response
200 .content_length()
201 .unwrap_or(metadata.size_mb as u64 * 1_000_000);
202
203 let is_tar_gz = metadata.url.ends_with(".tar.gz");
204
205 if is_tar_gz {
206 let temp_path = destination.with_extension("tar.gz");
207 let mut file = tokio::fs::File::create(&temp_path)
208 .await
209 .map_err(|e| SpatialError::IoError(format!("Failed to create file: {}", e)))?;
210
211 let mut downloaded = 0u64;
212 let mut stream = response.bytes_stream();
213 use futures_util::StreamExt;
214
215 while let Some(chunk) = stream.next().await {
216 let chunk = chunk.map_err(|e| SpatialError::Other(format!("Download interrupted: {}", e)))?;
217 file.write_all(&chunk)
218 .await
219 .map_err(|e| SpatialError::IoError(format!("Failed to write to file: {}", e)))?;
220 downloaded += chunk.len() as u64;
221 if let Some(ref mut f) = progress_fn {
222 f(downloaded, total_bytes);
223 }
224 }
225 drop(file);
226
227 let parent = destination
228 .parent()
229 .ok_or_else(|| SpatialError::IoError("Invalid destination path".to_string()))?;
230
231 let output = std::process::Command::new("tar")
232 .args(&["xzf"])
233 .arg(&temp_path)
234 .arg("-C")
235 .arg(parent)
236 .output()
237 .map_err(|e| SpatialError::IoError(format!("Failed to extract tar.gz: {}", e)))?;
238
239 if !output.status.success() {
240 let stderr = String::from_utf8_lossy(&output.stderr);
241 return Err(SpatialError::IoError(format!("tar extraction failed: {}", stderr)));
242 }
243
244 let _ = tokio::fs::remove_file(&temp_path).await;
245 } else {
246 let mut file = tokio::fs::File::create(destination)
247 .await
248 .map_err(|e| SpatialError::IoError(format!("Failed to create file: {}", e)))?;
249
250 let mut downloaded = 0u64;
251 let mut stream = response.bytes_stream();
252 use futures_util::StreamExt;
253
254 while let Some(chunk) = stream.next().await {
255 let chunk = chunk.map_err(|e| SpatialError::Other(format!("Download interrupted: {}", e)))?;
256 file.write_all(&chunk)
257 .await
258 .map_err(|e| SpatialError::IoError(format!("Failed to write to file: {}", e)))?;
259 downloaded += chunk.len() as u64;
260 if let Some(ref mut f) = progress_fn {
261 f(downloaded, total_bytes);
262 }
263 }
264 }
265
266 tracing::info!("Model downloaded: {:?}", destination);
267 Ok(())
268}