//! spatial_maker — model.rs: depth-model metadata, checkpoint discovery, and download.

1use crate::error::{SpatialError, SpatialResult};
2use std::path::{Path, PathBuf};
3use tokio::io::AsyncWriteExt;
4
5pub fn get_checkpoint_dir() -> SpatialResult<PathBuf> {
6	if let Ok(custom_dir) = std::env::var("SPATIAL_MAKER_CHECKPOINTS") {
7		Ok(PathBuf::from(custom_dir))
8	} else {
9		let home = dirs::home_dir().ok_or_else(|| {
10			SpatialError::ConfigError("Could not determine home directory".to_string())
11		})?;
12		Ok(home.join(".spatial-maker").join("checkpoints"))
13	}
14}
15
/// Descriptive record for a downloadable depth-estimation model.
#[derive(Clone, Debug)]
pub struct ModelMetadata {
	/// Human-readable model identifier, e.g. "depth-anything-v2-small".
	pub name: String,
	/// Filename the checkpoint is stored under inside the checkpoint dir.
	pub filename: String,
	/// Download URL; may point at a `.tar.gz` archive that is extracted
	/// after download.
	pub url: String,
	/// Approximate download size in megabytes; used as a fallback total
	/// for progress reporting when the server sends no Content-Length.
	pub size_mb: u32,
}
23
24impl ModelMetadata {
25	pub fn coreml(encoder_size: &str) -> SpatialResult<Self> {
26		match encoder_size {
27			"s" | "small" => Ok(ModelMetadata {
28				name: "depth-anything-v2-small".to_string(),
29				filename: "DepthAnythingV2SmallF16.mlpackage".to_string(),
30				url: "https://huggingface.co/mrgnw/depth-anything-v2-coreml/resolve/main/DepthAnythingV2SmallF16.mlpackage.tar.gz".to_string(),
31				size_mb: 48,
32			}),
33			"b" | "base" => Ok(ModelMetadata {
34				name: "depth-anything-v2-base".to_string(),
35				filename: "DepthAnythingV2BaseF16.mlpackage".to_string(),
36				url: "https://huggingface.co/mrgnw/depth-anything-v2-coreml/resolve/main/DepthAnythingV2BaseF16.mlpackage.tar.gz".to_string(),
37				size_mb: 186,
38			}),
39			"l" | "large" => Ok(ModelMetadata {
40				name: "depth-anything-v2-large".to_string(),
41				filename: "DepthAnythingV2LargeF16.mlpackage".to_string(),
42				url: "https://huggingface.co/mrgnw/depth-anything-v2-coreml/resolve/main/DepthAnythingV2LargeF16.mlpackage.tar.gz".to_string(),
43				size_mb: 638,
44			}),
45			other => Err(SpatialError::ConfigError(
46				format!("Unknown encoder size: '{}'. Use 's', 'b', or 'l'", other)
47			)),
48		}
49	}
50
51	#[cfg(feature = "onnx")]
52	pub fn onnx(encoder_size: &str) -> SpatialResult<Self> {
53		match encoder_size {
54			"s" | "small" => Ok(ModelMetadata {
55				name: "depth-anything-v2-small".to_string(),
56				filename: "depth_anything_v2_small.onnx".to_string(),
57				url: "https://huggingface.co/onnx-community/depth-anything-v2-small/resolve/main/onnx/model.onnx".to_string(),
58				size_mb: 99,
59			}),
60			"b" | "base" => Ok(ModelMetadata {
61				name: "depth-anything-v2-base".to_string(),
62				filename: "depth_anything_v2_base.onnx".to_string(),
63				url: "https://huggingface.co/onnx-community/depth-anything-v2-base/resolve/main/onnx/model.onnx".to_string(),
64				size_mb: 380,
65			}),
66			"l" | "large" => Ok(ModelMetadata {
67				name: "depth-anything-v2-large".to_string(),
68				filename: "depth_anything_v2_large.onnx".to_string(),
69				url: "https://huggingface.co/onnx-community/depth-anything-v2-large/resolve/main/onnx/model.onnx".to_string(),
70				size_mb: 1300,
71			}),
72			other => Err(SpatialError::ConfigError(
73				format!("Unknown encoder size: '{}'. Use 's', 'b', or 'l'", other)
74			)),
75		}
76	}
77}
78
/// Locates an existing model checkpoint for `encoder_size` on disk.
///
/// Search order:
/// 1. The configured checkpoint directory, by the backend-specific
///    filename — Core ML first (macOS + `coreml` feature), then ONNX
///    (`onnx` feature).
/// 2. Development fallback paths (`./checkpoints` and
///    `~/.spatial-maker/checkpoints`), matched by a fuzzy filename scan.
///
/// # Errors
/// Returns `SpatialError::ModelError` when nothing matches.
pub fn find_model(encoder_size: &str) -> SpatialResult<PathBuf> {
	let checkpoint_dir = get_checkpoint_dir()?;

	// Prefer the Core ML package on macOS builds with the feature enabled.
	#[cfg(all(target_os = "macos", feature = "coreml"))]
	{
		let meta = ModelMetadata::coreml(encoder_size)?;
		let model_path = checkpoint_dir.join(&meta.filename);
		if model_path.exists() {
			return Ok(model_path);
		}
	}

	// Fall back to the ONNX checkpoint when that backend is compiled in.
	#[cfg(feature = "onnx")]
	{
		let meta = ModelMetadata::onnx(encoder_size)?;
		let model_path = checkpoint_dir.join(&meta.filename);
		if model_path.exists() {
			return Ok(model_path);
		}
	}

	// Also check development paths
	let dev_paths = [
		PathBuf::from("checkpoints"),
		dirs::home_dir()
			.unwrap_or_default()
			.join(".spatial-maker")
			.join("checkpoints"),
	];

	for dir in &dev_paths {
		if dir.exists() {
			if let Ok(entries) = std::fs::read_dir(dir) {
				for entry in entries.flatten() {
					let name = entry.file_name().to_string_lossy().to_string();
					// Skip archives and in-progress downloads.
					if name.ends_with(".tar.gz") || name.ends_with(".downloading") {
						continue;
					}
					if name.contains("DepthAnything") || name.contains("depth_anything") {
						let lower_size = encoder_size.to_lowercase();
						let name_lower = name.to_lowercase();
						// Match the requested encoder size against the filename.
						let matches = match lower_size.as_str() {
							"s" | "small" => name_lower.contains("small"),
							"b" | "base" => name_lower.contains("base"),
							"l" | "large" => name_lower.contains("large"),
							_ => false,
						};
						if matches {
							return Ok(entry.path());
						}
					}
				}
			}
		}
	}

	Err(SpatialError::ModelError(format!(
		"Model not found for encoder size '{}'. Run download first.",
		encoder_size
	)))
}
140
141pub fn model_exists(encoder_size: &str) -> bool {
142	find_model(encoder_size).is_ok()
143}
144
/// Returns the path to the model for `encoder_size`, downloading it first
/// when no checkpoint is found on disk.
///
/// `progress_fn`, when supplied, receives `(downloaded_bytes, total_bytes)`
/// as the download progresses.
///
/// # Errors
/// Fails when the checkpoint directory cannot be created, the download
/// fails, or no depth backend feature (`coreml`/`onnx`) is enabled.
pub async fn ensure_model_exists<F>(
	encoder_size: &str,
	progress_fn: Option<F>,
) -> SpatialResult<PathBuf>
where
	F: FnMut(u64, u64),
{
	// Fast path: a usable checkpoint already exists.
	if let Ok(path) = find_model(encoder_size) {
		return Ok(path);
	}

	let checkpoint_dir = get_checkpoint_dir()?;
	tokio::fs::create_dir_all(&checkpoint_dir)
		.await
		.map_err(|e| {
			SpatialError::IoError(format!("Failed to create checkpoint directory: {}", e))
		})?;

	// Backend preference mirrors find_model: Core ML first on macOS.
	#[cfg(all(target_os = "macos", feature = "coreml"))]
	{
		let meta = ModelMetadata::coreml(encoder_size)?;
		let model_path = checkpoint_dir.join(&meta.filename);
		download_model(&meta, &model_path, progress_fn).await?;
		return Ok(model_path);
	}

	// ONNX is used only when the Core ML backend is not in this build.
	#[cfg(all(feature = "onnx", not(all(target_os = "macos", feature = "coreml"))))]
	{
		let meta = ModelMetadata::onnx(encoder_size)?;
		let model_path = checkpoint_dir.join(&meta.filename);
		download_model(&meta, &model_path, progress_fn).await?;
		return Ok(model_path);
	}

	// No backend compiled in; consume progress_fn to avoid an unused warning.
	#[cfg(not(any(all(target_os = "macos", feature = "coreml"), feature = "onnx")))]
	{
		let _ = progress_fn;
		Err(SpatialError::ConfigError(
			"No depth backend enabled. Enable 'coreml' (macOS) or 'onnx' feature.".to_string(),
		))
	}
}
187
188async fn download_model<F>(
189	metadata: &ModelMetadata,
190	destination: &Path,
191	mut progress_fn: Option<F>,
192) -> SpatialResult<()>
193where
194	F: FnMut(u64, u64),
195{
196	eprintln!("Downloading model: {} ({} MB)...", metadata.name, metadata.size_mb);
197	tracing::info!("Downloading model: {} from {}", metadata.name, metadata.url);
198
199	let response = reqwest::get(&metadata.url)
200		.await
201		.map_err(|e| SpatialError::Other(format!("Failed to download model: {}", e)))?;
202
203	if !response.status().is_success() {
204		return Err(SpatialError::Other(format!(
205			"Failed to download model: HTTP {} from {}",
206			response.status(),
207			metadata.url
208		)));
209	}
210
211	let total_bytes = response
212		.content_length()
213		.unwrap_or(metadata.size_mb as u64 * 1_000_000);
214
215	let is_tar_gz = metadata.url.ends_with(".tar.gz");
216
217	if is_tar_gz {
218		let temp_path = destination.with_extension("tar.gz");
219		let mut file = tokio::fs::File::create(&temp_path)
220			.await
221			.map_err(|e| SpatialError::IoError(format!("Failed to create file: {}", e)))?;
222
223		let mut downloaded = 0u64;
224		let mut stream = response.bytes_stream();
225		use futures_util::StreamExt;
226
227		let mut last_pct: u64 = 0;
228		while let Some(chunk) = stream.next().await {
229			let chunk = chunk.map_err(|e| SpatialError::Other(format!("Download interrupted: {}", e)))?;
230			file.write_all(&chunk)
231				.await
232				.map_err(|e| SpatialError::IoError(format!("Failed to write to file: {}", e)))?;
233			downloaded += chunk.len() as u64;
234			if let Some(ref mut f) = progress_fn {
235				f(downloaded, total_bytes);
236			}
237			if total_bytes > 0 {
238				let pct = downloaded * 100 / total_bytes;
239				if pct != last_pct {
240					last_pct = pct;
241					eprint!("\rDownloading... {}%", pct);
242				}
243			}
244		}
245		eprintln!();
246		drop(file);
247
248		let parent = destination
249			.parent()
250			.ok_or_else(|| SpatialError::IoError("Invalid destination path".to_string()))?;
251
252		eprintln!("Extracting...");
253		let output = std::process::Command::new("tar")
254			.args(&["xzf"])
255			.arg(&temp_path)
256			.arg("-C")
257			.arg(parent)
258			.output()
259			.map_err(|e| SpatialError::IoError(format!("Failed to extract tar.gz: {}", e)))?;
260
261		if !output.status.success() {
262			let stderr = String::from_utf8_lossy(&output.stderr);
263			return Err(SpatialError::IoError(format!("tar extraction failed: {}", stderr)));
264		}
265
266		let _ = tokio::fs::remove_file(&temp_path).await;
267
268		if !destination.exists() {
269			return Err(SpatialError::ModelError(format!(
270				"Extraction succeeded but model not found at {:?}",
271				destination
272			)));
273		}
274	} else {
275		let mut file = tokio::fs::File::create(destination)
276			.await
277			.map_err(|e| SpatialError::IoError(format!("Failed to create file: {}", e)))?;
278
279		let mut downloaded = 0u64;
280		let mut stream = response.bytes_stream();
281		use futures_util::StreamExt;
282
283		while let Some(chunk) = stream.next().await {
284			let chunk = chunk.map_err(|e| SpatialError::Other(format!("Download interrupted: {}", e)))?;
285			file.write_all(&chunk)
286				.await
287				.map_err(|e| SpatialError::IoError(format!("Failed to write to file: {}", e)))?;
288			downloaded += chunk.len() as u64;
289			if let Some(ref mut f) = progress_fn {
290				f(downloaded, total_bytes);
291			}
292		}
293	}
294
295	tracing::info!("Model downloaded: {:?}", destination);
296	Ok(())
297}