1#[cfg(feature = "encrypted")]
22use std::sync::Arc;
23
24use chrono::{DateTime, Utc};
25use serde::{de::DeserializeOwned, Serialize};
26use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
27use sqlx::{Row, SqlitePool};
28
29use crate::error::{KernelError, Result};
30use crate::module::{ModuleKind, ModuleMetadata, ModuleState};
31
/// AES-256-GCM cipher used to encrypt `config` values at rest.
///
/// The keyed cipher lives behind an [`Arc`], so cloning a `RegistryCipher`
/// (and therefore a registry that owns one) shares a single instance instead
/// of duplicating key material.
#[cfg(feature = "encrypted")]
#[derive(Clone)]
struct RegistryCipher {
    /// Keyed AES-256-GCM instance, shared between clones.
    cipher: Arc<aes_gcm::Aes256Gcm>,
}
48
#[cfg(feature = "encrypted")]
impl RegistryCipher {
    /// Length in bytes of the AES-GCM nonce prepended to every ciphertext.
    const NONCE_LEN: usize = 12;

    /// Builds a cipher whose 256-bit AES key is the SHA-256 digest of `key_str`.
    ///
    /// NOTE(review): a single unsalted hash is not a password KDF
    /// (Argon2/scrypt/PBKDF2), so low-entropy passphrases are cheap to brute
    /// force. Changing the derivation would make existing databases
    /// unreadable, so it needs a migration path before it can be changed.
    fn new(key_str: &str) -> Self {
        use aes_gcm::aead::KeyInit;
        use sha2::{Digest, Sha256};

        let key_bytes = Sha256::digest(key_str.as_bytes());
        Self {
            cipher: Arc::new(aes_gcm::Aes256Gcm::new(&key_bytes)),
        }
    }

    /// Encrypts `plaintext` under a fresh random nonce.
    ///
    /// Returns `base64url(nonce || ciphertext_with_tag)`, the format that
    /// [`decrypt`](Self::decrypt) expects.
    ///
    /// NOTE(review): no associated data is bound, so a ciphertext copied from
    /// one config key to another still decrypts. Adding AAD would change the
    /// stored format.
    fn encrypt(&self, plaintext: &str) -> anyhow::Result<String> {
        use aes_gcm::aead::{Aead, AeadCore, OsRng};
        use base64ct::{Base64Url, Encoding};

        let nonce = aes_gcm::Aes256Gcm::generate_nonce(&mut OsRng);
        // `Aead` operates through `&self`, so the shared cipher is used in
        // place instead of cloning its key schedule for every call.
        let ciphertext = self
            .cipher
            .encrypt(&nonce, plaintext.as_bytes())
            .map_err(|e| anyhow::anyhow!("encrypt: {e}"))?;

        let mut blob = Vec::with_capacity(Self::NONCE_LEN + ciphertext.len());
        blob.extend_from_slice(nonce.as_slice());
        blob.extend_from_slice(&ciphertext);
        Ok(Base64Url::encode_string(&blob))
    }

    /// Reverses [`encrypt`](Self::encrypt): base64url-decodes `encoded`,
    /// splits off the leading nonce, authenticates and decrypts the rest, and
    /// returns the result as UTF-8.
    ///
    /// # Errors
    ///
    /// Fails on invalid base64, a blob shorter than the nonce, an
    /// authentication failure (wrong key or tampered data), or non-UTF-8
    /// plaintext.
    fn decrypt(&self, encoded: &str) -> anyhow::Result<String> {
        use aes_gcm::aead::Aead;
        use base64ct::{Base64Url, Encoding};

        let blob =
            Base64Url::decode_vec(encoded).map_err(|e| anyhow::anyhow!("base64 decode: {e}"))?;
        if blob.len() < Self::NONCE_LEN {
            return Err(anyhow::anyhow!("encrypted blob too short"));
        }
        let (nonce_bytes, ciphertext) = blob.split_at(Self::NONCE_LEN);
        let nonce = aes_gcm::Nonce::from_slice(nonce_bytes);
        let plaintext = self
            .cipher
            .decrypt(nonce, ciphertext)
            .map_err(|e| anyhow::anyhow!("decrypt: {e}"))?;
        String::from_utf8(plaintext).map_err(|e| anyhow::anyhow!("utf8: {e}"))
    }
}
98
99#[derive(Debug, Clone)]
101pub struct ModuleRecord {
102 pub id: String,
104 pub name: String,
106 pub version: String,
108 pub kind: ModuleKind,
110 pub state: ModuleState,
112 pub description: Option<String>,
114 pub updated_at: DateTime<Utc>,
116}
117
118#[derive(Clone)]
126pub struct StateRegistry {
127 pool: SqlitePool,
128 #[cfg(feature = "encrypted")]
129 cipher: Option<RegistryCipher>,
130}
131
132impl StateRegistry {
133 pub async fn in_memory() -> Result<Self> {
136 let options = SqliteConnectOptions::new()
137 .in_memory(true)
138 .create_if_missing(true);
139 let pool = SqlitePoolOptions::new()
140 .max_connections(1)
141 .connect_with(options)
142 .await?;
143
144 let registry = Self {
145 pool,
146 #[cfg(feature = "encrypted")]
147 cipher: None,
148 };
149 registry.migrate().await?;
150 Ok(registry)
151 }
152
153 pub async fn connect(path: &str) -> Result<Self> {
155 let options = SqliteConnectOptions::new()
156 .filename(path)
157 .create_if_missing(true);
158 let pool = SqlitePoolOptions::new()
159 .max_connections(5)
160 .connect_with(options)
161 .await?;
162
163 let registry = Self {
164 pool,
165 #[cfg(feature = "encrypted")]
166 cipher: None,
167 };
168 registry.migrate().await?;
169 Ok(registry)
170 }
171
172 #[cfg(feature = "encrypted")]
182 pub async fn open_encrypted(path: &str, key_str: &str) -> Result<Self> {
183 let options = SqliteConnectOptions::new()
184 .filename(path)
185 .create_if_missing(true);
186 let pool = SqlitePoolOptions::new()
187 .max_connections(5)
188 .connect_with(options)
189 .await?;
190
191 let registry = Self {
192 pool,
193 cipher: Some(RegistryCipher::new(key_str)),
194 };
195 registry.migrate().await?;
196 Ok(registry)
197 }
198
199 #[cfg(feature = "encrypted")]
204 pub fn key_from_env() -> Result<String> {
205 std::env::var("OXIDE_REGISTRY_KEY").map_err(|_| {
206 KernelError::Other(anyhow::anyhow!(
207 "OXIDE_REGISTRY_KEY env var not set; \
208 set it or call open_encrypted with an explicit key"
209 ))
210 })
211 }
212
213 pub fn pool(&self) -> &sqlx::SqlitePool {
217 &self.pool
218 }
219
220 async fn migrate(&self) -> Result<()> {
222 sqlx::query(
223 r#"
224 CREATE TABLE IF NOT EXISTS modules (
225 id TEXT PRIMARY KEY,
226 name TEXT NOT NULL,
227 version TEXT NOT NULL,
228 kind TEXT NOT NULL,
229 state TEXT NOT NULL,
230 description TEXT,
231 updated_at TEXT NOT NULL
232 )
233 "#,
234 )
235 .execute(&self.pool)
236 .await?;
237
238 sqlx::query(
239 r#"
240 CREATE TABLE IF NOT EXISTS config (
241 key TEXT PRIMARY KEY,
242 value TEXT NOT NULL,
243 updated_at TEXT NOT NULL
244 )
245 "#,
246 )
247 .execute(&self.pool)
248 .await?;
249
250 Ok(())
251 }
252
253 pub async fn upsert_module(&self, metadata: &ModuleMetadata, state: ModuleState) -> Result<()> {
259 let now = Utc::now().to_rfc3339();
260 let kind = serde_json::to_string(&metadata.kind)?;
261 let state_str = serde_json::to_string(&state)?;
262
263 sqlx::query(
264 r#"
265 INSERT INTO modules (id, name, version, kind, state, description, updated_at)
266 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
267 ON CONFLICT(id) DO UPDATE SET
268 name = excluded.name,
269 version = excluded.version,
270 kind = excluded.kind,
271 state = excluded.state,
272 description = excluded.description,
273 updated_at = excluded.updated_at
274 "#,
275 )
276 .bind(&metadata.id)
277 .bind(&metadata.name)
278 .bind(&metadata.version)
279 .bind(kind)
280 .bind(state_str)
281 .bind(metadata.description.as_deref())
282 .bind(now)
283 .execute(&self.pool)
284 .await?;
285
286 Ok(())
287 }
288
289 pub async fn set_module_state(&self, id: &str, state: ModuleState) -> Result<()> {
291 let state_str = serde_json::to_string(&state)?;
292 let now = Utc::now().to_rfc3339();
293 let res = sqlx::query("UPDATE modules SET state = ?1, updated_at = ?2 WHERE id = ?3")
294 .bind(state_str)
295 .bind(now)
296 .bind(id)
297 .execute(&self.pool)
298 .await?;
299
300 if res.rows_affected() == 0 {
301 return Err(KernelError::UnknownModule(id.to_string()));
302 }
303 Ok(())
304 }
305
306 pub async fn get_module(&self, id: &str) -> Result<Option<ModuleRecord>> {
308 let row = sqlx::query(
309 "SELECT id, name, version, kind, state, description, updated_at FROM modules WHERE id = ?1",
310 )
311 .bind(id)
312 .fetch_optional(&self.pool)
313 .await?;
314
315 row.map(row_to_module_record).transpose()
316 }
317
318 pub async fn list_modules(&self) -> Result<Vec<ModuleRecord>> {
320 let rows = sqlx::query(
321 "SELECT id, name, version, kind, state, description, updated_at FROM modules ORDER BY id",
322 )
323 .fetch_all(&self.pool)
324 .await?;
325
326 rows.into_iter().map(row_to_module_record).collect()
327 }
328
329 pub async fn set_config<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
338 let json = serde_json::to_string(value)?;
339
340 #[cfg(feature = "encrypted")]
341 let stored = if let Some(cipher) = &self.cipher {
342 cipher
343 .encrypt(&json)
344 .map_err(|e| KernelError::Other(anyhow::anyhow!("config encrypt: {e}")))?
345 } else {
346 json
347 };
348
349 #[cfg(not(feature = "encrypted"))]
350 let stored = json;
351
352 let now = Utc::now().to_rfc3339();
353 sqlx::query(
354 r#"
355 INSERT INTO config (key, value, updated_at)
356 VALUES (?1, ?2, ?3)
357 ON CONFLICT(key) DO UPDATE SET
358 value = excluded.value,
359 updated_at = excluded.updated_at
360 "#,
361 )
362 .bind(key)
363 .bind(stored)
364 .bind(now)
365 .execute(&self.pool)
366 .await?;
367 Ok(())
368 }
369
370 pub async fn get_config<T: DeserializeOwned>(&self, key: &str) -> Result<T> {
375 let row = sqlx::query("SELECT value FROM config WHERE key = ?1")
376 .bind(key)
377 .fetch_optional(&self.pool)
378 .await?;
379
380 let Some(row) = row else {
381 return Err(KernelError::ConfigNotFound(key.to_string()));
382 };
383
384 let stored: String = row.try_get("value").map_err(KernelError::Registry)?;
385
386 #[cfg(feature = "encrypted")]
387 let json = if let Some(cipher) = &self.cipher {
388 cipher
389 .decrypt(&stored)
390 .map_err(|e| KernelError::Other(anyhow::anyhow!("config decrypt: {e}")))?
391 } else {
392 stored
393 };
394
395 #[cfg(not(feature = "encrypted"))]
396 let json = stored;
397
398 let parsed: T = serde_json::from_str(&json)?;
399 Ok(parsed)
400 }
401
402 pub async fn try_get_config<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
404 match self.get_config::<T>(key).await {
405 Ok(v) => Ok(Some(v)),
406 Err(KernelError::ConfigNotFound(_)) => Ok(None),
407 Err(e) => Err(e),
408 }
409 }
410
411 pub async fn delete_config(&self, key: &str) -> Result<bool> {
413 let res = sqlx::query("DELETE FROM config WHERE key = ?1")
414 .bind(key)
415 .execute(&self.pool)
416 .await?;
417 Ok(res.rows_affected() > 0)
418 }
419}
420
421fn row_to_module_record(row: sqlx::sqlite::SqliteRow) -> Result<ModuleRecord> {
422 let kind: String = row.try_get("kind").map_err(KernelError::Registry)?;
423 let state: String = row.try_get("state").map_err(KernelError::Registry)?;
424 let updated_at: String = row.try_get("updated_at").map_err(KernelError::Registry)?;
425
426 let updated_at = DateTime::parse_from_rfc3339(&updated_at)
427 .map_err(|e| KernelError::Other(anyhow::anyhow!("invalid updated_at: {e}")))?
428 .with_timezone(&Utc);
429
430 Ok(ModuleRecord {
431 id: row.try_get("id").map_err(KernelError::Registry)?,
432 name: row.try_get("name").map_err(KernelError::Registry)?,
433 version: row.try_get("version").map_err(KernelError::Registry)?,
434 kind: serde_json::from_str(&kind)?,
435 state: serde_json::from_str(&state)?,
436 description: row.try_get("description").map_err(KernelError::Registry)?,
437 updated_at,
438 })
439}
440
/// Tests for the `encrypted` feature: round-tripping, opacity of stored
/// values, wrong-key rejection, and nonce freshness.
#[cfg(all(test, feature = "encrypted"))]
mod encrypted_tests {
    use super::*;
    use tempfile::NamedTempFile;

    #[tokio::test]
    async fn encrypted_config_round_trips() {
        let tmp = NamedTempFile::new().unwrap();
        let path = tmp.path().to_str().unwrap();

        {
            let reg = StateRegistry::open_encrypted(path, "s3cr3t-key")
                .await
                .unwrap();
            reg.set_config("api_token", &"my-super-secret-token")
                .await
                .unwrap();
            let v: String = reg.get_config("api_token").await.unwrap();
            assert_eq!(v, "my-super-secret-token");
        }

        // A second registry on the same file with the same key reads it back.
        let reg2 = StateRegistry::open_encrypted(path, "s3cr3t-key")
            .await
            .unwrap();
        let v2: String = reg2.get_config("api_token").await.unwrap();
        assert_eq!(v2, "my-super-secret-token");
    }

    #[tokio::test]
    async fn encrypted_value_not_readable_as_plain_json() {
        let tmp = NamedTempFile::new().unwrap();
        let reg = StateRegistry::open_encrypted(tmp.path().to_str().unwrap(), "key")
            .await
            .unwrap();
        reg.set_config("secret", &42i32).await.unwrap();

        // A plain registry sees the base64 ciphertext, which is not valid JSON.
        let plain = StateRegistry::connect(tmp.path().to_str().unwrap())
            .await
            .unwrap();
        let err = plain.get_config::<i32>("secret").await.unwrap_err();
        assert!(
            matches!(err, KernelError::Serde(_)),
            "expected Serde error, got {err}"
        );
    }

    #[tokio::test]
    async fn wrong_key_cannot_decrypt() {
        let tmp = NamedTempFile::new().unwrap();
        let path = tmp.path().to_str().unwrap();

        let reg = StateRegistry::open_encrypted(path, "right-key").await.unwrap();
        reg.set_config("secret", &"value").await.unwrap();

        // Opening succeeds; the wrong key only shows up as an AEAD failure.
        let wrong = StateRegistry::open_encrypted(path, "wrong-key").await.unwrap();
        let err = wrong.get_config::<String>("secret").await.unwrap_err();
        assert!(
            matches!(err, KernelError::Other(_)),
            "expected decrypt error, got {err}"
        );
    }

    #[tokio::test]
    async fn identical_values_get_distinct_ciphertexts() {
        let tmp = NamedTempFile::new().unwrap();
        let reg = StateRegistry::open_encrypted(tmp.path().to_str().unwrap(), "key")
            .await
            .unwrap();
        reg.set_config("a", &"same").await.unwrap();
        reg.set_config("b", &"same").await.unwrap();

        let stored: Vec<String> = sqlx::query_scalar("SELECT value FROM config ORDER BY key")
            .fetch_all(reg.pool())
            .await
            .unwrap();
        assert_eq!(stored.len(), 2);
        // A fresh random nonce per write means equal plaintexts never collide.
        assert_ne!(stored[0], stored[1]);
        assert!(stored.iter().all(|v| v != "\"same\""));
    }
}
494
495impl std::fmt::Debug for StateRegistry {
496 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
497 f.debug_struct("StateRegistry").finish_non_exhaustive()
498 }
499}
500
/// Unit tests for the unencrypted registry, run against in-memory databases.
#[cfg(test)]
mod tests {
    use super::*;

    /// Metadata fixture whose name is derived from `id`.
    fn sample_meta(id: &str) -> ModuleMetadata {
        ModuleMetadata {
            id: id.into(),
            name: format!("Module {id}"),
            version: "0.1.0".into(),
            kind: ModuleKind::Native,
            description: Some("test module".into()),
        }
    }

    #[tokio::test]
    async fn in_memory_registry_runs_migrations() {
        let reg = StateRegistry::in_memory().await.unwrap();
        assert!(reg.list_modules().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn upsert_and_get_module() {
        let reg = StateRegistry::in_memory().await.unwrap();
        let meta = sample_meta("mirror");
        reg.upsert_module(&meta, ModuleState::Loaded).await.unwrap();

        let rec = reg.get_module("mirror").await.unwrap().expect("record");
        assert_eq!(rec.id, "mirror");
        assert_eq!(rec.kind, ModuleKind::Native);
        assert_eq!(rec.state, ModuleState::Loaded);

        // Upserting the same id overwrites rather than duplicating.
        reg.upsert_module(&meta, ModuleState::Running)
            .await
            .unwrap();
        let rec = reg.get_module("mirror").await.unwrap().expect("record");
        assert_eq!(rec.state, ModuleState::Running);
        assert_eq!(reg.list_modules().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn get_module_missing_returns_none() {
        let reg = StateRegistry::in_memory().await.unwrap();
        assert!(reg.get_module("missing").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn list_modules_is_ordered_by_id() {
        let reg = StateRegistry::in_memory().await.unwrap();
        for id in ["zeta", "alpha", "mid"] {
            reg.upsert_module(&sample_meta(id), ModuleState::Loaded)
                .await
                .unwrap();
        }
        let ids: Vec<String> = reg
            .list_modules()
            .await
            .unwrap()
            .into_iter()
            .map(|rec| rec.id)
            .collect();
        assert_eq!(ids, ["alpha", "mid", "zeta"]);
    }

    #[tokio::test]
    async fn set_module_state_updates_only_state() {
        let reg = StateRegistry::in_memory().await.unwrap();
        let meta = sample_meta("compress");
        reg.upsert_module(&meta, ModuleState::Loaded).await.unwrap();

        reg.set_module_state("compress", ModuleState::Running)
            .await
            .unwrap();
        let rec = reg.get_module("compress").await.unwrap().unwrap();
        assert_eq!(rec.state, ModuleState::Running);
        assert_eq!(rec.name, meta.name);
    }

    #[tokio::test]
    async fn set_module_state_unknown_errors() {
        let reg = StateRegistry::in_memory().await.unwrap();
        let err = reg
            .set_module_state("missing", ModuleState::Running)
            .await
            .unwrap_err();
        assert!(matches!(err, KernelError::UnknownModule(_)));
    }

    #[tokio::test]
    async fn config_round_trip() {
        let reg = StateRegistry::in_memory().await.unwrap();
        reg.set_config("max_threads", &8u32).await.unwrap();
        let val: u32 = reg.get_config("max_threads").await.unwrap();
        assert_eq!(val, 8);

        // Overwriting an existing key replaces the value.
        reg.set_config("max_threads", &16u32).await.unwrap();
        let val: u32 = reg.get_config("max_threads").await.unwrap();
        assert_eq!(val, 16);
    }

    #[tokio::test]
    async fn config_missing_returns_not_found() {
        let reg = StateRegistry::in_memory().await.unwrap();
        let err = reg.get_config::<String>("missing").await.unwrap_err();
        assert!(matches!(err, KernelError::ConfigNotFound(_)));

        let opt: Option<String> = reg.try_get_config("missing").await.unwrap();
        assert!(opt.is_none());
    }

    #[tokio::test]
    async fn try_get_config_returns_present_value() {
        let reg = StateRegistry::in_memory().await.unwrap();
        reg.set_config("retries", &3u8).await.unwrap();
        let opt: Option<u8> = reg.try_get_config("retries").await.unwrap();
        assert_eq!(opt, Some(3));
    }

    #[tokio::test]
    async fn config_type_mismatch_is_serde_error() {
        let reg = StateRegistry::in_memory().await.unwrap();
        reg.set_config("name", &"not a number").await.unwrap();

        let err = reg.get_config::<u32>("name").await.unwrap_err();
        assert!(
            matches!(err, KernelError::Serde(_)),
            "expected Serde error, got {err}"
        );
        // try_get_config only maps ConfigNotFound to None; other errors surface.
        assert!(reg.try_get_config::<u32>("name").await.is_err());
    }

    #[tokio::test]
    async fn delete_config_returns_whether_removed() {
        let reg = StateRegistry::in_memory().await.unwrap();
        reg.set_config("key", &"value").await.unwrap();
        assert!(reg.delete_config("key").await.unwrap());
        assert!(!reg.delete_config("key").await.unwrap());
    }

    #[tokio::test]
    async fn config_supports_complex_types() {
        #[derive(Serialize, serde::Deserialize, PartialEq, Debug)]
        struct Endpoint {
            url: String,
            retries: u8,
        }

        let reg = StateRegistry::in_memory().await.unwrap();
        let ep = Endpoint {
            url: "https://example.com".into(),
            retries: 3,
        };
        reg.set_config("endpoint", &ep).await.unwrap();
        let round: Endpoint = reg.get_config("endpoint").await.unwrap();
        assert_eq!(round, ep);
    }
}