1#![forbid(unsafe_code)]
7#![warn(missing_docs, missing_debug_implementations)]
8
9use std::fmt::Write;
10
11use agent_proxy_rust_storage::{
12 AvailableChannelInfo, AvailableModelInfo, Channel, CostAggregate, CostFilter, CostGroupBy,
13 CostRecord, Model, ModelMapping, Provider, SeedManager, SeedStatus, Storage, StorageError,
14 SubscriptionFee, SwitchLog, TimeRange,
15};
16use async_trait::async_trait;
17use r2d2::Pool;
18use r2d2_sqlite::SqliteConnectionManager;
19use rusqlite::params;
20use secrecy::{ExposeSecret, SecretString};
21use tracing::debug;
22
23mod seed;
24
/// SQL script `001_init.sql`, embedded into the binary at compile time via
/// `include_str!` (the path is resolved relative to this source file).
///
/// NOTE(review): not referenced in the visible part of this file; presumably
/// applied by a migration routine elsewhere — confirm.
const MIGRATION_V1: &str = include_str!("../migrations/001_init.sql");
26
/// SQLite-backed implementation of [`Storage`].
///
/// All operations obtain a connection from an `r2d2` pool. Cloning a
/// `SqliteStorage` clones the pool handle, so clones share the same pool.
#[derive(Debug, Clone)]
pub struct SqliteStorage {
    /// Connection pool used by every storage operation.
    pool: Pool<SqliteConnectionManager>,
}
36
37impl SqliteStorage {
38 pub fn new(path: &std::path::Path) -> Result<Self, StorageError> {
45 let manager = SqliteConnectionManager::file(path);
46 let pool = Pool::builder()
47 .max_size(4)
48 .build(manager)
49 .map_err(|e| StorageError::Connection(format!("failed to create pool: {e}")))?;
50 debug!(path = %path.display(), "SQLite database opened");
51 Ok(Self { pool })
52 }
53
54 pub fn new_in_memory() -> Result<Self, StorageError> {
62 let manager = SqliteConnectionManager::memory();
63 let pool = Pool::builder()
64 .max_size(4)
65 .build(manager)
66 .map_err(|e| StorageError::Connection(format!("failed to create pool: {e}")))?;
67 debug!("SQLite in-memory database opened");
68 Ok(Self { pool })
69 }
70}
71
72impl SqliteStorage {
73 fn now_unix() -> i64 {
74 chrono::Utc::now().timestamp()
75 }
76
77 fn get_pool(&self) -> Pool<SqliteConnectionManager> {
78 self.pool.clone()
79 }
80
81 fn row_to_channel(row: &rusqlite::Row) -> rusqlite::Result<Channel> {
82 Ok(Channel {
83 id: row.get(0)?,
84 name: row.get(1)?,
85 api_key: SecretString::new(row.get::<_, String>(2)?.into_boxed_str()),
86 protocol: row.get(3)?,
87 protocols: row.get::<_, String>(4).unwrap_or_default(),
88 is_builtin: row.get(5)?,
89 enabled: row.get(6)?,
90 created_at: row.get(7)?,
91 updated_at: row.get(8)?,
92 health_status: row.get(9)?,
93 cooldown_until: row.get(10)?,
94 consecutive_failures: row.get(11)?,
95 billing_type: row.get(12)?,
96 monthly_quota: row.get(13)?,
97 quota_policy: row.get(14)?,
98 priority: row.get(15)?,
99 force_protocol: row.get(16)?,
100 })
101 }
102
103 const CHANNEL_COLS: &'static str = "id, name, api_key, protocol, protocols, is_builtin, \
104 enabled, created_at, updated_at, health_status, \
105 cooldown_until, consecutive_failures, billing_type, \
106 monthly_quota, quota_policy, priority, force_protocol";
107}
108
109#[async_trait]
110impl Storage for SqliteStorage {
111 async fn list_providers(&self) -> Result<Vec<Provider>, StorageError> {
114 let pool = self.get_pool();
115 tokio::task::spawn_blocking(move || {
116 let conn = pool
117 .get()
118 .map_err(|e| StorageError::Connection(e.to_string()))?;
119 let mut stmt = conn
120 .prepare("SELECT id, name, created_at FROM providers ORDER BY id")
121 .map_err(|e| StorageError::Backend(e.to_string()))?;
122 let rows = stmt
123 .query_map([], |row| {
124 Ok(Provider {
125 id: row.get(0)?,
126 name: row.get(1)?,
127 created_at: row.get::<_, i64>(2).map_or_else(
128 |_| String::new(),
129 |ts| {
130 chrono::DateTime::from_timestamp(ts, 0)
131 .unwrap_or_default()
132 .to_rfc3339()
133 },
134 ),
135 })
136 })
137 .map_err(|e| StorageError::Backend(e.to_string()))?;
138 let mut providers = Vec::new();
139 for row in rows {
140 providers.push(row.map_err(|e| StorageError::Backend(e.to_string()))?);
141 }
142 Ok(providers)
143 })
144 .await
145 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
146 }
147
148 async fn get_provider(&self, id: &str) -> Result<Option<Provider>, StorageError> {
149 let id = id.to_string();
150 let pool = self.get_pool();
151 tokio::task::spawn_blocking(move || {
152 let conn = pool
153 .get()
154 .map_err(|e| StorageError::Connection(e.to_string()))?;
155 let mut stmt = conn
156 .prepare("SELECT id, name, created_at FROM providers WHERE id = ?1")
157 .map_err(|e| StorageError::Backend(e.to_string()))?;
158 let mut rows = stmt
159 .query_map(params![id], |row| {
160 Ok(Provider {
161 id: row.get(0)?,
162 name: row.get(1)?,
163 created_at: row.get::<_, i64>(2).map_or_else(
164 |_| String::new(),
165 |ts| {
166 chrono::DateTime::from_timestamp(ts, 0)
167 .unwrap_or_default()
168 .to_rfc3339()
169 },
170 ),
171 })
172 })
173 .map_err(|e| StorageError::Backend(e.to_string()))?;
174 match rows.next() {
175 Some(Ok(p)) => Ok(Some(p)),
176 Some(Err(e)) => Err(StorageError::Backend(e.to_string())),
177 None => Ok(None),
178 }
179 })
180 .await
181 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
182 }
183
184 async fn list_models(&self, provider_id: Option<&str>) -> Result<Vec<Model>, StorageError> {
187 let provider_id = provider_id.map(String::from);
188 let pool = self.get_pool();
189 tokio::task::spawn_blocking(move || {
190 let conn = pool
191 .get()
192 .map_err(|e| StorageError::Connection(e.to_string()))?;
193 let (sql, param_values): (&str, Vec<String>) = match &provider_id {
194 Some(pid) => (
195 "SELECT m.id, m.provider_id, m.client_name, m.price_input, m.price_output, \
196 m.currency, m.context_window, m.created_at, \
197 COALESCE((SELECT COUNT(*) FROM model_mappings WHERE client_name = m.client_name), 0) as channel_count \
198 FROM models m WHERE m.provider_id = ?1 ORDER BY m.client_name",
199 vec![pid.clone()],
200 ),
201 None => (
202 "SELECT m.id, m.provider_id, m.client_name, m.price_input, m.price_output, \
203 m.currency, m.context_window, m.created_at, \
204 COALESCE((SELECT COUNT(*) FROM model_mappings WHERE client_name = m.client_name), 0) as channel_count \
205 FROM models m ORDER BY m.provider_id, m.client_name",
206 vec![],
207 ),
208 };
209 let mut stmt = conn
210 .prepare(sql)
211 .map_err(|e| StorageError::Backend(e.to_string()))?;
212 let params_refs: Vec<&dyn rusqlite::types::ToSql> = param_values
213 .iter()
214 .map(|s| s as &dyn rusqlite::types::ToSql)
215 .collect();
216 let rows = stmt
217 .query_map(params_refs.as_slice(), |row| {
218 Ok(Model {
219 id: row.get(0)?,
220 provider_id: row.get(1)?,
221 client_name: row.get(2)?,
222 price_input: row.get(3)?,
223 price_output: row.get(4)?,
224 currency: row.get(5)?,
225 context_window: row.get(6)?,
226 created_at: row.get::<_, i64>(7).map(|ts| {
227 chrono::DateTime::from_timestamp(ts, 0)
228 .unwrap_or_default()
229 .to_rfc3339()
230 }).unwrap_or_default(),
231 channel_count: row.get(8)?,
232 })
233 })
234 .map_err(|e| StorageError::Backend(e.to_string()))?;
235 let mut models = Vec::new();
236 for row in rows {
237 models.push(row.map_err(|e| StorageError::Backend(e.to_string()))?);
238 }
239 Ok(models)
240 })
241 .await
242 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
243 }
244
245 async fn get_model(&self, id: &str) -> Result<Option<Model>, StorageError> {
246 let id = id.to_string();
247 let pool = self.get_pool();
248 tokio::task::spawn_blocking(move || {
249 let conn = pool
250 .get()
251 .map_err(|e| StorageError::Connection(e.to_string()))?;
252 let mut stmt = conn
253 .prepare(
254 "SELECT m.id, m.provider_id, m.client_name, m.price_input, m.price_output, \
255 m.currency, m.context_window, m.created_at, \
256 COALESCE((SELECT COUNT(*) FROM model_mappings WHERE client_name = m.client_name), 0) \
257 FROM models m WHERE m.id = ?1",
258 )
259 .map_err(|e| StorageError::Backend(e.to_string()))?;
260 let mut rows = stmt
261 .query_map(params![id], |row| {
262 Ok(Model {
263 id: row.get(0)?,
264 provider_id: row.get(1)?,
265 client_name: row.get(2)?,
266 price_input: row.get(3)?,
267 price_output: row.get(4)?,
268 currency: row.get(5)?,
269 context_window: row.get(6)?,
270 created_at: row.get::<_, i64>(7).map(|ts| {
271 chrono::DateTime::from_timestamp(ts, 0)
272 .unwrap_or_default()
273 .to_rfc3339()
274 }).unwrap_or_default(),
275 channel_count: row.get(8)?,
276 })
277 })
278 .map_err(|e| StorageError::Backend(e.to_string()))?;
279 match rows.next() {
280 Some(Ok(m)) => Ok(Some(m)),
281 Some(Err(e)) => Err(StorageError::Backend(e.to_string())),
282 None => Ok(None),
283 }
284 })
285 .await
286 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
287 }
288
289 async fn list_channels(&self, model_id: Option<&str>) -> Result<Vec<Channel>, StorageError> {
292 let model_id = model_id.map(String::from);
293 let pool = self.get_pool();
294 tokio::task::spawn_blocking(move || {
295 let conn = pool
296 .get()
297 .map_err(|e| StorageError::Connection(e.to_string()))?;
298 let (sql, params_vec) = match &model_id {
299 Some(mid) => (
300 format!(
301 "SELECT {} FROM channels WHERE id IN (SELECT channel_id FROM \
302 model_mappings WHERE client_name = ?1) ORDER BY id",
303 SqliteStorage::CHANNEL_COLS
304 ),
305 vec![mid.clone()],
306 ),
307 None => (
308 format!(
309 "SELECT {} FROM channels ORDER BY priority, id",
310 SqliteStorage::CHANNEL_COLS
311 ),
312 vec![],
313 ),
314 };
315 let mut stmt = conn
316 .prepare(&sql)
317 .map_err(|e| StorageError::Backend(e.to_string()))?;
318 let params_refs: Vec<&dyn rusqlite::types::ToSql> = params_vec
319 .iter()
320 .map(|s| s as &dyn rusqlite::types::ToSql)
321 .collect();
322 let rows = stmt
323 .query_map(params_refs.as_slice(), SqliteStorage::row_to_channel)
324 .map_err(|e| StorageError::Backend(e.to_string()))?;
325 let mut channels = Vec::new();
326 for row in rows {
327 channels.push(row.map_err(|e| StorageError::Backend(e.to_string()))?);
328 }
329 Ok(channels)
330 })
331 .await
332 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
333 }
334
335 async fn get_channel(&self, id: &str) -> Result<Option<Channel>, StorageError> {
336 let id = id.to_string();
337 let pool = self.get_pool();
338 tokio::task::spawn_blocking(move || {
339 let conn = pool
340 .get()
341 .map_err(|e| StorageError::Connection(e.to_string()))?;
342 let sql = format!(
343 "SELECT {} FROM channels WHERE id = ?1",
344 SqliteStorage::CHANNEL_COLS
345 );
346 let mut stmt = conn
347 .prepare(&sql)
348 .map_err(|e| StorageError::Backend(e.to_string()))?;
349 let mut rows = stmt
350 .query_map(params![id], SqliteStorage::row_to_channel)
351 .map_err(|e| StorageError::Backend(e.to_string()))?;
352 match rows.next() {
353 Some(Ok(ch)) => Ok(Some(ch)),
354 Some(Err(e)) => Err(StorageError::Backend(e.to_string())),
355 None => Ok(None),
356 }
357 })
358 .await
359 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
360 }
361
362 async fn upsert_channel(&self, channel: &Channel) -> Result<(), StorageError> {
363 let id = channel.id.clone();
364 let name = channel.name.clone();
365 let api_key = channel.api_key.expose_secret().to_string();
366 let protocol = channel.protocol.clone();
367 let protocols = channel.protocols.clone();
368 let is_builtin = channel.is_builtin;
369 let enabled = channel.enabled;
370 let now = Self::now_unix();
371 let health_status = channel.health_status.clone();
372 let billing_type = channel.billing_type.clone();
373 let monthly_quota = channel.monthly_quota;
374 let quota_policy = channel.quota_policy.clone();
375 let priority = channel.priority;
376 let force_protocol = channel.force_protocol.clone();
377 let pool = self.get_pool();
378
379 tokio::task::spawn_blocking(move || {
380 let conn = pool
381 .get()
382 .map_err(|e| StorageError::Connection(e.to_string()))?;
383 conn.execute(
384 "INSERT INTO channels (id, name, api_key, protocol, protocols, is_builtin, \
385 enabled, created_at, updated_at, health_status, billing_type, \
386 monthly_quota, quota_policy, priority, force_protocol)
387 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15)
388 ON CONFLICT(id) DO UPDATE SET
389 name = excluded.name,
390 api_key = excluded.api_key,
391 protocol = excluded.protocol,
392 protocols = excluded.protocols,
393 is_builtin = excluded.is_builtin,
394 enabled = excluded.enabled,
395 updated_at = excluded.updated_at,
396 health_status = excluded.health_status,
397 billing_type = excluded.billing_type,
398 monthly_quota = excluded.monthly_quota,
399 quota_policy = excluded.quota_policy,
400 priority = excluded.priority,
401 force_protocol = excluded.force_protocol",
402 params![
403 id,
404 name,
405 api_key,
406 protocol,
407 protocols,
408 is_builtin,
409 enabled,
410 now,
411 now,
412 health_status,
413 billing_type,
414 monthly_quota,
415 quota_policy,
416 priority,
417 force_protocol,
418 ],
419 )
420 .map_err(|e| StorageError::Backend(e.to_string()))?;
421 Ok(())
422 })
423 .await
424 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
425 }
426
427 async fn set_channel_enabled(&self, id: &str, enabled: bool) -> Result<(), StorageError> {
428 let id = id.to_string();
429 let now = Self::now_unix();
430 let pool = self.get_pool();
431
432 tokio::task::spawn_blocking(move || {
433 let conn = pool
434 .get()
435 .map_err(|e| StorageError::Connection(e.to_string()))?;
436 let rows = conn
437 .execute(
438 "UPDATE channels SET enabled = ?1, updated_at = ?2 WHERE id = ?3",
439 params![enabled, now, id],
440 )
441 .map_err(|e| StorageError::Backend(e.to_string()))?;
442 if rows == 0 {
443 return Err(StorageError::NotFound(format!("channel not found: {id}")));
444 }
445 Ok(())
446 })
447 .await
448 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
449 }
450
451 async fn set_channel_api_key(&self, id: &str, key: &SecretString) -> Result<(), StorageError> {
452 let id = id.to_string();
453 let api_key = key.expose_secret().to_string();
454 let now = Self::now_unix();
455 let pool = self.get_pool();
456
457 tokio::task::spawn_blocking(move || {
458 let conn = pool
459 .get()
460 .map_err(|e| StorageError::Connection(e.to_string()))?;
461 let rows = conn
462 .execute(
463 "UPDATE channels SET api_key = ?1, updated_at = ?2 WHERE id = ?3",
464 params![api_key, now, id],
465 )
466 .map_err(|e| StorageError::Backend(e.to_string()))?;
467 if rows == 0 {
468 return Err(StorageError::NotFound(format!("channel not found: {id}")));
469 }
470 Ok(())
471 })
472 .await
473 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
474 }
475
476 #[allow(clippy::too_many_arguments)]
477 async fn update_channel(
478 &self,
479 id: &str,
480 name: Option<&str>,
481 enabled: Option<bool>,
482 priority: Option<u32>,
483 monthly_quota: Option<u64>,
484 quota_policy: Option<&str>,
485 protocols: Option<&str>,
486 force_protocol: Option<&str>,
487 ) -> Result<Channel, StorageError> {
488 let id = id.to_string();
489 let name = name.map(String::from);
490 let quota_policy = quota_policy.map(String::from);
491 let protocols = protocols.map(String::from);
492 let force_protocol = force_protocol.map(String::from);
493 let now = Self::now_unix();
494 let pool = self.get_pool();
495
496 tokio::task::spawn_blocking(move || {
497 let conn = pool
498 .get()
499 .map_err(|e| StorageError::Connection(e.to_string()))?;
500
501 let mut sets = vec!["updated_at = ?1".to_string()];
503 let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(now)];
504
505 if let Some(ref n) = name {
506 sets.push(format!("name = ?{}", param_values.len() + 1));
507 param_values.push(Box::new(n.clone()));
508 }
509 if let Some(e) = enabled {
510 sets.push(format!("enabled = ?{}", param_values.len() + 1));
511 param_values.push(Box::new(e));
512 }
513 if let Some(p) = priority {
514 sets.push(format!("priority = ?{}", param_values.len() + 1));
515 param_values.push(Box::new(p));
516 }
517 if let Some(q) = monthly_quota {
518 sets.push(format!("monthly_quota = ?{}", param_values.len() + 1));
519 param_values.push(Box::new(i64::try_from(q).unwrap_or(i64::MAX)));
520 }
521 if let Some(ref qp) = quota_policy {
522 sets.push(format!("quota_policy = ?{}", param_values.len() + 1));
523 param_values.push(Box::new(qp.clone()));
524 }
525 if let Some(ref p) = protocols {
526 sets.push(format!("protocols = ?{}", param_values.len() + 1));
527 param_values.push(Box::new(p.clone()));
528 }
529 if let Some(ref fp) = force_protocol {
530 sets.push(format!("force_protocol = ?{}", param_values.len() + 1));
531 param_values.push(Box::new(fp.clone()));
532 }
533
534 let id_param_idx = param_values.len() + 1;
535 param_values.push(Box::new(id.clone()));
536
537 let sql = format!(
538 "UPDATE channels SET {} WHERE id = ?{id_param_idx}",
539 sets.join(", "),
540 );
541
542 let params_refs: Vec<&dyn rusqlite::types::ToSql> =
543 param_values.iter().map(AsRef::as_ref).collect();
544
545 let rows = conn
546 .execute(&sql, params_refs.as_slice())
547 .map_err(|e| StorageError::Backend(e.to_string()))?;
548 if rows == 0 {
549 return Err(StorageError::NotFound(format!("channel not found: {id}")));
550 }
551
552 let channel_sql = format!(
554 "SELECT {} FROM channels WHERE id = ?1",
555 SqliteStorage::CHANNEL_COLS
556 );
557 let mut stmt = conn
558 .prepare(&channel_sql)
559 .map_err(|e| StorageError::Backend(e.to_string()))?;
560 let updated = stmt
561 .query_row(params![id], SqliteStorage::row_to_channel)
562 .map_err(|e| StorageError::Backend(e.to_string()))?;
563 Ok(updated)
564 })
565 .await
566 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
567 }
568
569 async fn delete_channel(&self, id: &str) -> Result<(), StorageError> {
570 let id = id.to_string();
571 let pool = self.get_pool();
572
573 tokio::task::spawn_blocking(move || {
574 let conn = pool
575 .get()
576 .map_err(|e| StorageError::Connection(e.to_string()))?;
577 let rows = conn
578 .execute("DELETE FROM channels WHERE id = ?1", params![id])
579 .map_err(|e| StorageError::Backend(e.to_string()))?;
580 if rows == 0 {
581 return Err(StorageError::NotFound(format!("channel not found: {id}")));
582 }
583 Ok(())
584 })
585 .await
586 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
587 }
588
589 async fn mark_channel_healthy(&self, id: &str) -> Result<(), StorageError> {
590 let id = id.to_string();
591 let now = Self::now_unix();
592 let pool = self.get_pool();
593
594 tokio::task::spawn_blocking(move || {
595 let conn = pool
596 .get()
597 .map_err(|e| StorageError::Connection(e.to_string()))?;
598 let rows = conn
599 .execute(
600 "UPDATE channels SET health_status = 'Healthy', cooldown_until = NULL, \
601 consecutive_failures = 0, updated_at = ?1 WHERE id = ?2",
602 params![now, id],
603 )
604 .map_err(|e| StorageError::Backend(e.to_string()))?;
605 if rows == 0 {
606 return Err(StorageError::NotFound(format!("channel not found: {id}")));
607 }
608 Ok(())
609 })
610 .await
611 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
612 }
613
614 async fn record_channel_failure(&self, id: &str) -> Result<(), StorageError> {
615 let id = id.to_string();
616 let now = Self::now_unix();
617 let pool = self.get_pool();
618
619 tokio::task::spawn_blocking(move || {
620 let conn = pool
621 .get()
622 .map_err(|e| StorageError::Connection(e.to_string()))?;
623
624 conn.execute(
626 "UPDATE channels SET
627 consecutive_failures = consecutive_failures + 1,
628 updated_at = ?1
629 WHERE id = ?2",
630 params![now, id],
631 )
632 .map_err(|e| StorageError::Backend(e.to_string()))?;
633
634 let failures: i32 = conn
636 .query_row(
637 "SELECT consecutive_failures FROM channels WHERE id = ?1",
638 params![id],
639 |row| row.get(0),
640 )
641 .map_err(|e| StorageError::Backend(e.to_string()))?;
642
643 let status = if failures >= 3 {
644 "Cooldown"
645 } else if failures >= 1 {
646 "Degraded"
647 } else {
648 "Healthy"
649 };
650
651 let cooldown_sql = if status == "Cooldown" {
652 format!(
653 ", cooldown_until = '{}'",
654 chrono::Utc::now()
655 .checked_add_signed(chrono::Duration::minutes(5))
656 .unwrap_or(chrono::Utc::now())
657 .to_rfc3339()
658 )
659 } else {
660 String::new()
661 };
662
663 conn.execute(
664 &format!("UPDATE channels SET health_status = ?1 {cooldown_sql} WHERE id = ?2"),
665 params![status, id],
666 )
667 .map_err(|e| StorageError::Backend(e.to_string()))?;
668
669 Ok(())
670 })
671 .await
672 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
673 }
674
675 async fn list_mappings(&self, channel_id: &str) -> Result<Vec<ModelMapping>, StorageError> {
678 let channel_id = channel_id.to_string();
679 let pool = self.get_pool();
680
681 tokio::task::spawn_blocking(move || {
682 let conn = pool
683 .get()
684 .map_err(|e| StorageError::Connection(e.to_string()))?;
685 let mut stmt = conn
686 .prepare(
687 "SELECT id, channel_id, client_name, upstream_name, billing, pricing_json, \
688 weight, enabled, protocols
689 FROM model_mappings WHERE channel_id = ?1 ORDER BY id",
690 )
691 .map_err(|e| StorageError::Backend(e.to_string()))?;
692 let rows = stmt
693 .query_map(params![channel_id], |row| {
694 Ok(ModelMapping {
695 id: row.get(0)?,
696 channel_id: row.get(1)?,
697 client_name: row.get(2)?,
698 upstream_name: row.get(3)?,
699 billing: row.get(4)?,
700 pricing_json: row.get(5)?,
701 weight: row.get(6)?,
702 enabled: row.get(7)?,
703 protocols: row.get(8)?,
704 })
705 })
706 .map_err(|e| StorageError::Backend(e.to_string()))?;
707 let mut mappings = Vec::new();
708 for row in rows {
709 mappings.push(row.map_err(|e| StorageError::Backend(e.to_string()))?);
710 }
711 Ok(mappings)
712 })
713 .await
714 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
715 }
716
717 async fn upsert_mapping(&self, mapping: &ModelMapping) -> Result<(), StorageError> {
718 let id = mapping.id.clone();
719 let channel_id = mapping.channel_id.clone();
720 let client_name = mapping.client_name.clone();
721 let upstream_name = mapping.upstream_name.clone();
722 let billing = mapping.billing.clone();
723 let pricing_json = mapping.pricing_json.clone();
724 let weight = mapping.weight;
725 let enabled = mapping.enabled;
726 let protocols = mapping.protocols.clone();
727 let pool = self.get_pool();
728
729 tokio::task::spawn_blocking(move || {
730 let conn = pool
731 .get()
732 .map_err(|e| StorageError::Connection(e.to_string()))?;
733 conn.execute(
734 "INSERT INTO model_mappings (id, channel_id, client_name, upstream_name, billing, \
735 pricing_json, weight, enabled, protocols)
736 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
737 ON CONFLICT(id) DO UPDATE SET
738 channel_id = excluded.channel_id,
739 client_name = excluded.client_name,
740 upstream_name = excluded.upstream_name,
741 billing = excluded.billing,
742 pricing_json = excluded.pricing_json,
743 weight = excluded.weight,
744 enabled = excluded.enabled,
745 protocols = excluded.protocols",
746 params![
747 id,
748 channel_id,
749 client_name,
750 upstream_name,
751 billing,
752 pricing_json,
753 weight,
754 enabled,
755 protocols,
756 ],
757 )
758 .map_err(|e| {
759 let msg = e.to_string();
760 if msg.contains("FOREIGN KEY") {
761 StorageError::NotFound(format!("channel not found: {channel_id}"))
762 } else {
763 StorageError::Backend(msg)
764 }
765 })?;
766 Ok(())
767 })
768 .await
769 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
770 }
771
772 async fn set_mapping_enabled(&self, id: &str, enabled: bool) -> Result<(), StorageError> {
773 let id = id.to_string();
774 let pool = self.get_pool();
775
776 tokio::task::spawn_blocking(move || {
777 let conn = pool
778 .get()
779 .map_err(|e| StorageError::Connection(e.to_string()))?;
780 let rows = conn
781 .execute(
782 "UPDATE model_mappings SET enabled = ?1 WHERE id = ?2",
783 params![enabled, id],
784 )
785 .map_err(|e| StorageError::Backend(e.to_string()))?;
786 if rows == 0 {
787 return Err(StorageError::NotFound(format!("mapping not found: {id}")));
788 }
789 Ok(())
790 })
791 .await
792 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
793 }
794
795 async fn delete_mapping(&self, id: &str) -> Result<(), StorageError> {
796 let id = id.to_string();
797 let pool = self.get_pool();
798
799 tokio::task::spawn_blocking(move || {
800 let conn = pool
801 .get()
802 .map_err(|e| StorageError::Connection(e.to_string()))?;
803 let rows = conn
804 .execute("DELETE FROM model_mappings WHERE id = ?1", params![id])
805 .map_err(|e| StorageError::Backend(e.to_string()))?;
806 if rows == 0 {
807 return Err(StorageError::NotFound(format!("mapping not found: {id}")));
808 }
809 Ok(())
810 })
811 .await
812 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
813 }
814
815 async fn list_all_mappings(&self) -> Result<Vec<ModelMapping>, StorageError> {
816 let pool = self.get_pool();
817 tokio::task::spawn_blocking(move || {
818 let conn = pool.get().map_err(|e| StorageError::Connection(e.to_string()))?;
819 let mut stmt = conn
820 .prepare("SELECT id, channel_id, client_name, upstream_name, billing, pricing_json, weight, enabled, protocols FROM model_mappings ORDER BY channel_id, client_name")
821 .map_err(|e| StorageError::Backend(e.to_string()))?;
822 let rows = stmt
823 .query_map([], |row| {
824 Ok(ModelMapping {
825 id: row.get(0)?,
826 channel_id: row.get(1)?,
827 client_name: row.get(2)?,
828 upstream_name: row.get(3)?,
829 billing: row.get(4)?,
830 pricing_json: row.get(5)?,
831 weight: row.get(6)?,
832 enabled: row.get(7)?,
833 protocols: row.get(8)?,
834 })
835 })
836 .map_err(|e| StorageError::Backend(e.to_string()))?;
837 let mut mappings = Vec::new();
838 for row in rows {
839 mappings.push(row.map_err(|e| StorageError::Backend(e.to_string()))?);
840 }
841 Ok(mappings)
842 })
843 .await
844 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
845 }
846
    /// Inserts one row into `cost_records` built from `record`.
    ///
    /// Every field is copied or cloned out of the borrowed `record` so that the
    /// owned values can be moved into the blocking task that performs the
    /// insert. Placeholders `?1`..`?28` are bound positionally: the column list
    /// in the SQL and the `params!` list below must stay in the same order.
    ///
    /// The statement has no `ON CONFLICT` clause. If `id` is a primary/unique
    /// key (schema not visible here — assumption), inserting a duplicate id
    /// fails.
    ///
    /// # Errors
    ///
    /// - [`StorageError::Connection`] if no pooled connection can be obtained.
    /// - [`StorageError::Backend`] if the `INSERT` fails or the blocking task
    ///   cannot be joined.
    async fn insert_cost_record(&self, record: &CostRecord) -> Result<(), StorageError> {
        let id = record.id.clone();
        let channel_id = record.channel_id.clone();
        let upstream_channel = record.upstream_channel.clone();
        let upstream_model = record.upstream_model.clone();
        let request_time_ms = record.request_time_ms;
        let project = record.project.clone();
        let user_id = record.user_id.clone();
        let agent_type = record.agent_type.clone();
        let input_tokens = record.input_tokens;
        let output_tokens = record.output_tokens;
        let cache_write_tokens = record.cache_write_tokens;
        let cache_read_tokens = record.cache_read_tokens;
        let thinking_tokens = record.thinking_tokens;
        let cost = record.cost;
        let schema_saved_tokens = record.schema_saved_tokens;
        let response_saved_tokens = record.response_saved_tokens;
        let rtk_saved_tokens = record.rtk_saved_tokens;
        let pre_compress_tokens = record.pre_compress_tokens;
        let post_compress_tokens = record.post_compress_tokens;
        let compression_tokens_saved = record.compression_tokens_saved;
        let pricing_snapshot_json = record.pricing_snapshot_json.clone();
        let unit = record.unit.clone();
        let timestamp = record.timestamp.clone();
        let session_id = record.session_id.clone();
        let before_tokens = record.before_tokens;
        let after_tokens = record.after_tokens;
        let tokens_saved = record.tokens_saved;
        // Bound to the `compression_breakdown_json` column (placeholder ?28).
        let compression_breakdown = record.compression_breakdown_json.clone();
        let pool = self.get_pool();

        // Run the synchronous rusqlite insert on tokio's blocking thread pool.
        tokio::task::spawn_blocking(move || {
            let conn = pool
                .get()
                .map_err(|e| StorageError::Connection(e.to_string()))?;
            conn.execute(
                "INSERT INTO cost_records
 (id, channel_id, upstream_channel, upstream_model, request_time_ms, project, user_id, agent_type,
 input_tokens, output_tokens, cache_write_tokens, cache_read_tokens,
 thinking_tokens, cost,
 schema_saved_tokens, response_saved_tokens, rtk_saved_tokens,
 pre_compress_tokens, post_compress_tokens, compression_tokens_saved,
 unit, pricing_snapshot_json, timestamp,
 session_id, before_tokens, after_tokens, tokens_saved, compression_breakdown_json)
 VALUES (?1,?2,?3,?4,?5,?6,?7,?8,?9,?10,?11,?12,?13,?14,?15,?16,?17,?18,?19,?20,?21,?22,?23,
 ?24,?25,?26,?27,?28)",
                // Order must match the column list above, one value per placeholder.
                params![
                    id,
                    channel_id,
                    upstream_channel,
                    upstream_model,
                    request_time_ms,
                    project,
                    user_id,
                    agent_type,
                    input_tokens,
                    output_tokens,
                    cache_write_tokens,
                    cache_read_tokens,
                    thinking_tokens,
                    cost,
                    schema_saved_tokens,
                    response_saved_tokens,
                    rtk_saved_tokens,
                    pre_compress_tokens,
                    post_compress_tokens,
                    compression_tokens_saved,
                    unit,
                    pricing_snapshot_json,
                    timestamp,
                    session_id,
                    before_tokens,
                    after_tokens,
                    tokens_saved,
                    compression_breakdown,
                ],
            )
            .map_err(|e| StorageError::Backend(e.to_string()))?;
            Ok(())
        })
        .await
        .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
    }
932
933 async fn query_cost_records(
934 &self,
935 filter: CostFilter,
936 ) -> Result<Vec<CostRecord>, StorageError> {
937 let pool = self.get_pool();
938
939 tokio::task::spawn_blocking(move || {
940 let conn = pool
941 .get()
942 .map_err(|e| StorageError::Connection(e.to_string()))?;
943
944 let mut sql = String::from(
945 "SELECT id, channel_id, upstream_channel, upstream_model, request_time_ms, project, user_id, agent_type,
946 input_tokens, output_tokens, cache_write_tokens, cache_read_tokens,
947 thinking_tokens, cost,
948 schema_saved_tokens, response_saved_tokens, rtk_saved_tokens,
949 pre_compress_tokens, post_compress_tokens, compression_tokens_saved,
950 unit, pricing_snapshot_json, timestamp,
951 session_id, before_tokens, after_tokens, tokens_saved, compression_breakdown_json
952 FROM cost_records WHERE 1=1",
953 );
954 let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
955
956 if let Some(project_path) = filter.project_path {
957 sql.push_str(" AND project = ?");
958 param_values.push(Box::new(project_path));
959 }
960 if let Some(model_name) = filter.model_name {
961 sql.push_str(" AND upstream_model = ?");
962 param_values.push(Box::new(model_name));
963 }
964 if let Some(channel_name) = filter.channel_name {
965 sql.push_str(" AND channel_id = ?");
966 param_values.push(Box::new(channel_name));
967 }
968 if let Some(ref tr) = filter.time_range {
969 sql.push_str(" AND timestamp >= ? AND timestamp < ?");
970 let start_rfc = chrono::DateTime::from_timestamp(tr.start, 0)
971 .unwrap_or_default()
972 .to_rfc3339();
973 let end_rfc = chrono::DateTime::from_timestamp(tr.end, 0)
974 .unwrap_or_default()
975 .to_rfc3339();
976 param_values.push(Box::new(start_rfc));
977 param_values.push(Box::new(end_rfc));
978 }
979
980 sql.push_str(" ORDER BY timestamp DESC");
981
982 let limit = filter.limit.unwrap_or(100);
983 let offset = filter.offset.unwrap_or(0);
984 let _ = write!(sql, " LIMIT {limit} OFFSET {offset}");
985
986 let params_refs: Vec<&dyn rusqlite::types::ToSql> = param_values
987 .iter()
988 .map(std::convert::AsRef::as_ref)
989 .collect();
990
991 let mut stmt = conn
992 .prepare(&sql)
993 .map_err(|e| StorageError::Backend(e.to_string()))?;
994 let rows = stmt
995 .query_map(params_refs.as_slice(), |row| {
996 Ok(CostRecord {
997 id: row.get(0)?,
998 channel_id: row.get(1)?,
999 upstream_channel: row.get(2)?,
1000 upstream_model: row.get(3)?,
1001 request_time_ms: row.get(4)?,
1002 project: row.get(5)?,
1003 user_id: row.get(6)?,
1004 agent_type: row.get(7)?,
1005 input_tokens: row.get(8)?,
1006 output_tokens: row.get(9)?,
1007 cache_write_tokens: row.get(10)?,
1008 cache_read_tokens: row.get(11)?,
1009 thinking_tokens: row.get(12)?,
1010 cost: row.get(13)?,
1011 schema_saved_tokens: row.get(14)?,
1012 response_saved_tokens: row.get(15)?,
1013 rtk_saved_tokens: row.get(16)?,
1014 pre_compress_tokens: row.get(17)?,
1015 post_compress_tokens: row.get(18)?,
1016 compression_tokens_saved: row.get(19)?,
1017 unit: row.get(20)?,
1018 pricing_snapshot_json: row.get(21)?,
1019 timestamp: row.get(22)?,
1020 session_id: row.get(23)?,
1021 before_tokens: row.get(24)?,
1022 after_tokens: row.get(25)?,
1023 tokens_saved: row.get(26)?,
1024 compression_breakdown_json: row.get(27)?,
1025 })
1026 })
1027 .map_err(|e| StorageError::Backend(e.to_string()))?;
1028
1029 let mut records = Vec::new();
1030 for row in rows {
1031 records.push(row.map_err(|e| StorageError::Backend(e.to_string()))?);
1032 }
1033 Ok(records)
1034 })
1035 .await
1036 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
1037 }
1038
1039 async fn aggregate_costs(
1040 &self,
1041 group_by: CostGroupBy,
1042 range: TimeRange,
1043 ) -> Result<Vec<CostAggregate>, StorageError> {
1044 let pool = self.get_pool();
1045 let start_rfc = chrono::DateTime::from_timestamp(range.start, 0)
1048 .unwrap_or_default()
1049 .to_rfc3339();
1050 let end_rfc = chrono::DateTime::from_timestamp(range.end, 0)
1051 .unwrap_or_default()
1052 .to_rfc3339();
1053
1054 tokio::task::spawn_blocking(move || {
1055 let conn = pool
1056 .get()
1057 .map_err(|e| StorageError::Connection(e.to_string()))?;
1058
1059 let (group_key_expr, group_clause): (&str, &str) = match group_by {
1060 CostGroupBy::Project => ("project", "project"),
1061 CostGroupBy::Model | CostGroupBy::Channel => ("channel_id", "channel_id"),
1062 CostGroupBy::ProjectModelMonth => (
1063 "project || '|' || upstream_model || '|' || substr(timestamp, 1, 7)",
1064 "project, upstream_model",
1065 ),
1066 CostGroupBy::ProjectModelHour => (
1067 "project || '|' || upstream_model || '|' || substr(timestamp, 1, 13)",
1068 "project, upstream_model",
1069 ),
1070 CostGroupBy::Hourly => (
1071 "substr(timestamp, 1, 13)",
1072 "substr(timestamp, 1, 13)",
1073 ),
1074 CostGroupBy::Daily => (
1075 "substr(timestamp, 1, 10)",
1076 "substr(timestamp, 1, 10)",
1077 ),
1078 };
1079
1080 let project_filter = range
1082 .project
1083 .as_ref()
1084 .map(|_| " AND project = ?3")
1085 .unwrap_or("");
1086 let sql = format!(
1087 "SELECT {group_key_expr} as group_key,
1088 SUM(input_tokens) as total_input_tokens,
1089 SUM(output_tokens) as total_output_tokens,
1090 SUM(cost) as total_actual_cost,
1091 SUM(compression_tokens_saved) as total_compression_tokens_saved,
1092 COUNT(*) as request_count
1093 FROM cost_records
1094 WHERE timestamp >= ?1 AND timestamp < ?2{project_filter}
1095 GROUP BY {group_clause}
1096 ORDER BY total_actual_cost DESC"
1097 );
1098
1099 let mut stmt = conn
1100 .prepare(&sql)
1101 .map_err(|e| StorageError::Backend(e.to_string()))?;
1102
1103 let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![
1104 Box::new(start_rfc),
1105 Box::new(end_rfc),
1106 ];
1107 if let Some(ref proj) = range.project {
1108 params.push(Box::new(proj.clone()));
1109 }
1110
1111 let rows = stmt
1112 .query_map(
1113 rusqlite::params_from_iter(params.iter().map(|p| p.as_ref())),
1114 |row| {
1115 Ok(CostAggregate {
1116 group_key: row.get(0)?,
1117 total_input_tokens: row.get(1)?,
1118 total_output_tokens: row.get(2)?,
1119 total_actual_cost: row.get(3)?,
1120 total_compression_tokens_saved: row.get(4)?,
1121 request_count: row.get(5)?,
1122 })
1123 },
1124 )
1125 .map_err(|e| StorageError::Backend(e.to_string()))?;
1126
1127 let mut results = Vec::new();
1128 for row in rows {
1129 results.push(row.map_err(|e| StorageError::Backend(e.to_string()))?);
1130 }
1131 Ok(results)
1132 })
1133 .await
1134 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
1135 }
1136
1137 async fn prune_cost_records(&self, older_than_days: u32) -> Result<u64, StorageError> {
1138 let pool = self.get_pool();
1139
1140 tokio::task::spawn_blocking(move || {
1141 let conn = pool
1142 .get()
1143 .map_err(|e| StorageError::Connection(e.to_string()))?;
1144
1145 let cutoff = chrono::Utc::now()
1147 .checked_sub_signed(chrono::Duration::days(i64::from(older_than_days)))
1148 .unwrap_or(chrono::Utc::now())
1149 .to_rfc3339();
1150
1151 let deleted = conn
1152 .execute(
1153 "DELETE FROM cost_records WHERE timestamp < ?1",
1154 params![cutoff],
1155 )
1156 .map_err(|e| StorageError::Backend(e.to_string()))?;
1157 Ok(deleted as u64)
1158 })
1159 .await
1160 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
1161 }
1162
1163 async fn list_projects(&self) -> Result<Vec<String>, StorageError> {
1164 let pool = self.get_pool();
1165
1166 tokio::task::spawn_blocking(move || {
1167 let conn = pool
1168 .get()
1169 .map_err(|e| StorageError::Connection(e.to_string()))?;
1170
1171 let mut stmt = conn
1172 .prepare("SELECT DISTINCT project FROM cost_records ORDER BY project")
1173 .map_err(|e| StorageError::Backend(e.to_string()))?;
1174
1175 let projects: Vec<String> = stmt
1176 .query_map([], |row| row.get(0))
1177 .map_err(|e| StorageError::Backend(e.to_string()))?
1178 .filter_map(|r| r.ok())
1179 .collect();
1180
1181 Ok(projects)
1182 })
1183 .await
1184 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
1185 }
1186
1187 async fn insert_switch_log(&self, log: &SwitchLog) -> Result<(), StorageError> {
1190 let id = log.id.clone();
1191 let from_channel_id = log.from_channel_id.clone();
1192 let to_channel_id = log.to_channel_id.clone();
1193 let reason = log.reason.clone();
1194 let cost_record_id = log.cost_record_id.clone();
1195 let created_at = log.created_at.clone();
1196 let pool = self.get_pool();
1197
1198 tokio::task::spawn_blocking(move || {
1199 let conn = pool
1200 .get()
1201 .map_err(|e| StorageError::Connection(e.to_string()))?;
1202 conn.execute(
1203 "INSERT INTO switch_logs (id, from_channel_id, to_channel_id, reason, \
1204 cost_record_id, created_at)
1205 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
1206 params![
1207 id,
1208 from_channel_id,
1209 to_channel_id,
1210 reason,
1211 cost_record_id,
1212 created_at
1213 ],
1214 )
1215 .map_err(|e| StorageError::Backend(e.to_string()))?;
1216 Ok(())
1217 })
1218 .await
1219 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
1220 }
1221
1222 async fn insert_subscription_fee(&self, fee: &SubscriptionFee) -> Result<(), StorageError> {
1225 let channel_name = fee.channel_name.clone();
1226 let month = fee.month.clone();
1227 let monthly_price = fee.monthly_price;
1228 let currency = fee.currency.clone();
1229 let pool = self.get_pool();
1230
1231 tokio::task::spawn_blocking(move || {
1232 let conn = pool
1233 .get()
1234 .map_err(|e| StorageError::Connection(e.to_string()))?;
1235 conn.execute(
1236 "INSERT INTO subscription_fees (channel_name, month, monthly_price, currency)
1237 VALUES (?1, ?2, ?3, ?4)",
1238 params![channel_name, month, monthly_price, currency],
1239 )
1240 .map_err(|e| {
1241 let msg = e.to_string();
1242 if msg.contains("UNIQUE constraint") {
1243 StorageError::Duplicate(msg)
1244 } else {
1245 StorageError::Backend(msg)
1246 }
1247 })?;
1248 Ok(())
1249 })
1250 .await
1251 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
1252 }
1253
1254 async fn query_subscription_fees(
1255 &self,
1256 channel: Option<&str>,
1257 month: Option<&str>,
1258 ) -> Result<Vec<SubscriptionFee>, StorageError> {
1259 let channel = channel.map(String::from);
1260 let month = month.map(String::from);
1261 let pool = self.get_pool();
1262
1263 tokio::task::spawn_blocking(move || {
1264 let conn = pool
1265 .get()
1266 .map_err(|e| StorageError::Connection(e.to_string()))?;
1267
1268 let mut sql = String::from(
1269 "SELECT id, channel_name, month, monthly_price, currency FROM subscription_fees \
1270 WHERE 1=1",
1271 );
1272 let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
1273
1274 if let Some(ref ch) = channel {
1275 sql.push_str(" AND channel_name = ?");
1276 param_values.push(Box::new(ch.clone()));
1277 }
1278 if let Some(ref mo) = month {
1279 sql.push_str(" AND month = ?");
1280 param_values.push(Box::new(mo.clone()));
1281 }
1282
1283 sql.push_str(" ORDER BY month DESC, channel_name");
1284
1285 let params_refs: Vec<&dyn rusqlite::types::ToSql> =
1286 param_values.iter().map(AsRef::as_ref).collect();
1287
1288 let mut stmt = conn
1289 .prepare(&sql)
1290 .map_err(|e| StorageError::Backend(e.to_string()))?;
1291 let rows = stmt
1292 .query_map(params_refs.as_slice(), |row| {
1293 Ok(SubscriptionFee {
1294 id: row.get(0)?,
1295 channel_name: row.get(1)?,
1296 month: row.get(2)?,
1297 monthly_price: row.get(3)?,
1298 currency: row.get(4)?,
1299 })
1300 })
1301 .map_err(|e| StorageError::Backend(e.to_string()))?;
1302
1303 let mut fees = Vec::new();
1304 for row in rows {
1305 fees.push(row.map_err(|e| StorageError::Backend(e.to_string()))?);
1306 }
1307 Ok(fees)
1308 })
1309 .await
1310 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
1311 }
1312
1313 async fn query_switch_logs(&self, limit: Option<u32>) -> Result<Vec<SwitchLog>, StorageError> {
1316 let limit = limit.unwrap_or(20).min(100);
1317 let pool = self.get_pool();
1318
1319 tokio::task::spawn_blocking(move || {
1320 let conn = pool
1321 .get()
1322 .map_err(|e| StorageError::Connection(e.to_string()))?;
1323 let mut stmt = conn
1324 .prepare(
1325 "SELECT id, from_channel_id, to_channel_id, reason, cost_record_id, created_at
1326 FROM switch_logs ORDER BY created_at DESC LIMIT ?1",
1327 )
1328 .map_err(|e| StorageError::Backend(e.to_string()))?;
1329 let rows = stmt
1330 .query_map(params![limit], |row| {
1331 Ok(SwitchLog {
1332 id: row.get(0)?,
1333 from_channel_id: row.get(1)?,
1334 to_channel_id: row.get(2)?,
1335 reason: row.get(3)?,
1336 cost_record_id: row.get(4)?,
1337 created_at: row.get(5)?,
1338 })
1339 })
1340 .map_err(|e| StorageError::Backend(e.to_string()))?;
1341 let mut logs = Vec::new();
1342 for row in rows {
1343 logs.push(row.map_err(|e| StorageError::Backend(e.to_string()))?);
1344 }
1345 Ok(logs)
1346 })
1347 .await
1348 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
1349 }
1350
1351 async fn list_available_channels(&self) -> Result<Vec<AvailableChannelInfo>, StorageError> {
1354 let pool = self.get_pool();
1355
1356 tokio::task::spawn_blocking(move || {
1357 let conn = pool
1358 .get()
1359 .map_err(|e| StorageError::Connection(e.to_string()))?;
1360
1361 let mut ch_stmt = conn
1363 .prepare(
1364 "SELECT id, name, protocol, protocols, health_status
1365 FROM channels WHERE enabled = 1 ORDER BY priority, id",
1366 )
1367 .map_err(|e| StorageError::Backend(e.to_string()))?;
1368
1369 let channels: Vec<(String, String, String, String, String)> = ch_stmt
1370 .query_map([], |row| {
1371 Ok((
1372 row.get(0)?,
1373 row.get(1)?,
1374 row.get(2)?,
1375 row.get::<_, String>(3).unwrap_or_default(),
1376 row.get::<_, String>(4).unwrap_or_default(),
1377 ))
1378 })
1379 .map_err(|e| StorageError::Backend(e.to_string()))?
1380 .flatten()
1381 .collect();
1382
1383 let mut result = Vec::new();
1384 for (ch_id, ch_name, protocol, protocols, health) in channels {
1385 let mut m_stmt = conn
1387 .prepare(
1388 "SELECT id, client_name, upstream_name
1389 FROM model_mappings WHERE channel_id = ?1 AND enabled = 1
1390 ORDER BY client_name",
1391 )
1392 .map_err(|e| StorageError::Backend(e.to_string()))?;
1393
1394 let models: Vec<AvailableModelInfo> = m_stmt
1395 .query_map(params![ch_id], |row| {
1396 Ok(AvailableModelInfo {
1397 mapping_id: row.get(0)?,
1398 client_name: row.get(1)?,
1399 upstream_name: row.get(2)?,
1400 })
1401 })
1402 .map_err(|e| StorageError::Backend(e.to_string()))?
1403 .flatten()
1404 .collect();
1405
1406 if models.is_empty() {
1408 continue;
1409 }
1410
1411 result.push(AvailableChannelInfo {
1412 channel_id: ch_id,
1413 channel_name: ch_name,
1414 protocol,
1415 protocols,
1416 health_status: health,
1417 enabled: true,
1418 models,
1419 });
1420 }
1421 Ok(result)
1422 })
1423 .await
1424 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
1425 }
1426
1427 async fn migrate(&self) -> Result<(), StorageError> {
1430 let pool = self.get_pool();
1431
1432 tokio::task::spawn_blocking(move || {
1433 let conn = pool
1434 .get()
1435 .map_err(|e| StorageError::Connection(e.to_string()))?;
1436
1437 let version: i64 = conn
1438 .pragma_query_value(None, "user_version", |row| row.get(0))
1439 .unwrap_or(0);
1440
1441 if version < 1 {
1442 conn.execute_batch(MIGRATION_V1)
1443 .map_err(|e| StorageError::Migration(e.to_string()))?;
1444 }
1445
1446 conn.pragma_update(None, "user_version", 1)
1447 .map_err(|e| StorageError::Migration(e.to_string()))?;
1448
1449 Ok(())
1450 })
1451 .await
1452 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
1453 }
1454
1455 async fn health_check(&self) -> Result<bool, StorageError> {
1456 let pool = self.get_pool();
1457
1458 tokio::task::spawn_blocking(move || {
1459 let conn = pool
1460 .get()
1461 .map_err(|_| StorageError::Connection("unable to get connection".into()))?;
1462 conn.execute_batch("SELECT 1")
1463 .map_err(|e| StorageError::Connection(e.to_string()))?;
1464 Ok(true)
1465 })
1466 .await
1467 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
1468 }
1469
1470 fn max_connections(&self) -> usize {
1471 4
1472 }
1473}
1474
1475#[async_trait]
1478impl SeedManager for SqliteStorage {
1479 async fn seed_init(&self) -> Result<SeedStatus, StorageError> {
1480 let ops = seed::SeedOps::new(self.get_pool());
1481 tokio::task::spawn_blocking(move || ops.seed_init())
1482 .await
1483 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
1484 }
1485
1486 async fn seed_refresh(&self, url: Option<&str>) -> Result<SeedStatus, StorageError> {
1487 let url = url.map(String::from);
1488 let ops = seed::SeedOps::new(self.get_pool());
1489 tokio::task::spawn_blocking(move || ops.seed_refresh(url.as_deref()))
1490 .await
1491 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
1492 }
1493
1494 async fn seed_status(&self) -> Result<SeedStatus, StorageError> {
1495 let ops = seed::SeedOps::new(self.get_pool());
1496 tokio::task::spawn_blocking(move || ops.seed_status())
1497 .await
1498 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
1499 }
1500
1501 async fn seed_check_remote(&self, url: Option<&str>) -> Result<SeedStatus, StorageError> {
1502 let url = url.map(String::from);
1503 let ops = seed::SeedOps::new(self.get_pool());
1504 tokio::task::spawn_blocking(move || ops.seed_check_remote(url.as_deref()))
1505 .await
1506 .map_err(|e| StorageError::Backend(format!("join error: {e}")))?
1507 }
1508}
1509
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// Creates a migrated, seeded in-memory store for synchronous tests.
    ///
    /// A throwaway Tokio runtime drives the async `migrate`/`seed_init` calls.
    fn setup_in_memory() -> SqliteStorage {
        let storage = SqliteStorage::new_in_memory().expect("failed to create in-memory storage");
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(storage.migrate()).expect("migration failed");
        rt.block_on(storage.seed_init()).expect("seed init failed");
        storage
    }

    /// Async counterpart of `setup_in_memory` for `#[tokio::test]` functions.
    async fn setup_in_memory_async() -> SqliteStorage {
        let storage = SqliteStorage::new_in_memory().expect("failed to create in-memory storage");
        storage.migrate().await.expect("migration failed");
        storage.seed_init().await.expect("seed init failed");
        storage
    }

    /// Runs a parameterless query that yields a single integer (e.g. `COUNT(*)`).
    fn query_i64(conn: &rusqlite::Connection, sql: &str) -> i64 {
        conn.query_row(sql, [], |row| row.get(0)).unwrap()
    }

    /// Counts `sqlite_master` entries describing a table named `name`.
    fn table_count(conn: &rusqlite::Connection, name: &str) -> i64 {
        conn.query_row(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?1",
            rusqlite::params![name],
            |row| row.get(0),
        )
        .unwrap()
    }

    /// Reads text column `index` from every row returned by a `PRAGMA` query.
    fn pragma_column(conn: &rusqlite::Connection, pragma_sql: &str, index: usize) -> Vec<String> {
        let mut stmt = conn.prepare(pragma_sql).unwrap();
        let values: Vec<String> = stmt
            .query_map([], |row| row.get::<_, String>(index))
            .unwrap()
            .filter_map(std::result::Result::ok)
            .collect();
        values
    }

    #[test]
    fn test_providers_table_exists() {
        let storage = setup_in_memory();
        let pool = storage.get_pool();
        let conn = pool.get().unwrap();

        assert_eq!(table_count(&conn, "providers"), 1, "providers table should exist");
    }

    #[test]
    fn test_models_table_exists() {
        let storage = setup_in_memory();
        let pool = storage.get_pool();
        let conn = pool.get().unwrap();

        assert_eq!(table_count(&conn, "models"), 1, "models table should exist");
    }

    #[test]
    fn test_providers_table_has_correct_columns() {
        let storage = setup_in_memory();
        let pool = storage.get_pool();
        let conn = pool.get().unwrap();

        let columns = pragma_column(&conn, "PRAGMA table_info('providers')", 1);
        for expected in ["id", "name", "created_at"] {
            assert!(
                columns.iter().any(|c| c.as_str() == expected),
                "providers should have '{expected}' column"
            );
        }
    }

    #[test]
    fn test_models_table_has_correct_columns() {
        let storage = setup_in_memory();
        let pool = storage.get_pool();
        let conn = pool.get().unwrap();

        let columns = pragma_column(&conn, "PRAGMA table_info('models')", 1);
        for expected in [
            "id",
            "provider_id",
            "client_name",
            "price_input",
            "price_output",
            "currency",
            "context_window",
        ] {
            assert!(
                columns.iter().any(|c| c.as_str() == expected),
                "models should have '{expected}' column"
            );
        }
    }

    #[test]
    fn test_models_foreign_key_to_providers() {
        let storage = setup_in_memory();
        let pool = storage.get_pool();
        let conn = pool.get().unwrap();

        // Column 2 of `PRAGMA foreign_key_list` is the referenced table.
        let fk_refs = pragma_column(&conn, "PRAGMA foreign_key_list('models')", 2);
        assert!(
            fk_refs.iter().any(|t| t.as_str() == "providers"),
            "models.provider_id should reference providers(id)"
        );
    }

    #[test]
    fn test_seed_providers_populated() {
        let storage = setup_in_memory();
        let pool = storage.get_pool();
        let conn = pool.get().unwrap();

        let count = query_i64(&conn, "SELECT COUNT(*) FROM providers");
        assert!(count >= 5, "should have at least 5 seeded providers, got {count}");
    }

    #[test]
    fn test_seed_models_populated() {
        let storage = setup_in_memory();
        let pool = storage.get_pool();
        let conn = pool.get().unwrap();

        let count = query_i64(&conn, "SELECT COUNT(*) FROM models");
        assert!(
            count >= 15,
            "should have at least 15 seeded models, got {count}"
        );
    }

    #[test]
    fn test_seed_providers_include_deepseek() {
        let storage = setup_in_memory();
        let pool = storage.get_pool();
        let conn = pool.get().unwrap();

        let name: String = conn
            .query_row(
                "SELECT name FROM providers WHERE id = 'deepseek'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(name, "DeepSeek");
    }

    #[test]
    fn test_seed_models_linked_to_providers() {
        let storage = setup_in_memory();
        let pool = storage.get_pool();
        let conn = pool.get().unwrap();

        let orphan_count = query_i64(
            &conn,
            "SELECT COUNT(*) FROM models WHERE provider_id NOT IN (SELECT id FROM providers)",
        );
        assert_eq!(
            orphan_count, 0,
            "all models must reference a valid provider"
        );
    }

    #[test]
    fn test_seed_models_include_deepseek_flash() {
        let storage = setup_in_memory();
        let pool = storage.get_pool();
        let conn = pool.get().unwrap();

        let count = query_i64(
            &conn,
            "SELECT COUNT(*) FROM models WHERE client_name = 'deepseek-v4-flash'",
        );
        assert_eq!(count, 1, "deepseek-v4-flash should exist in models");
    }

    #[test]
    fn test_seed_models_include_deepseek() {
        let storage = setup_in_memory();
        let pool = storage.get_pool();
        let conn = pool.get().unwrap();

        let count = query_i64(
            &conn,
            "SELECT COUNT(*) FROM models WHERE client_name IN ('deepseek-v4-pro', 'deepseek-v4-flash')",
        );
        assert_eq!(count, 2, "deepseek models should be seeded");
    }

    #[tokio::test]
    async fn test_migrate_is_idempotent() {
        let storage = setup_in_memory_async().await;
        storage
            .migrate()
            .await
            .expect("re-running migrate on a migrated database should succeed");

        let pool = storage.get_pool();
        let conn = pool.get().unwrap();
        let version: i64 = conn
            .pragma_query_value(None, "user_version", |row| row.get(0))
            .unwrap();
        assert_eq!(version, 1, "schema version should remain 1");
    }

    #[tokio::test]
    async fn test_storage_list_providers() {
        let storage = setup_in_memory_async().await;
        let providers = storage.list_providers().await.unwrap();
        assert!(!providers.is_empty(), "should return seeded providers");
        assert!(
            providers.iter().any(|p| p.name == "DeepSeek"),
            "should include DeepSeek"
        );
        assert!(
            providers.iter().any(|p| p.name == "Zhipu AI"),
            "should include Zhipu AI"
        );
    }

    #[tokio::test]
    async fn test_storage_get_provider_found() {
        let storage = setup_in_memory_async().await;
        let provider = storage.get_provider("deepseek").await.unwrap();
        assert!(provider.is_some(), "should find deepseek provider");
        assert_eq!(provider.unwrap().name, "DeepSeek");
    }

    #[tokio::test]
    async fn test_storage_get_provider_not_found() {
        let storage = setup_in_memory_async().await;
        let provider = storage.get_provider("nonexistent").await.unwrap();
        assert!(
            provider.is_none(),
            "should return None for unknown provider"
        );
    }

    #[tokio::test]
    async fn test_storage_list_models_unfiltered() {
        let storage = setup_in_memory_async().await;
        let models = storage.list_models(None).await.unwrap();
        assert!(!models.is_empty(), "should return seeded models");
    }

    #[tokio::test]
    async fn test_storage_list_models_filtered_by_provider() {
        let storage = setup_in_memory_async().await;
        let models = storage.list_models(Some("deepseek")).await.unwrap();
        assert!(!models.is_empty(), "should return models for deepseek");
        for m in &models {
            assert_eq!(
                m.provider_id, "deepseek",
                "all models should belong to deepseek"
            );
        }
    }
}