1use serde::{Deserialize, Serialize};
79use std::fs::File;
80use std::io::{self, BufReader};
81use std::path::{Path, PathBuf};
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct ConfigFile {
97 #[serde(default = "default_auto_pull")]
103 pub auto_pull: bool,
104
105 #[serde(default)]
109 pub models_home: Option<PathBuf>,
110
111 #[serde(default)]
115 pub model: Option<ModelConfig>,
116
117 #[serde(default = "default_n_ctx")]
121 pub n_ctx: u32,
122
123 #[serde(default)]
127 pub n_gpu_layers: i32,
128
129 #[serde(default)]
132 pub admin_addr: Option<String>,
133
134 #[serde(default)]
138 pub backends: Option<Vec<BackendEntry>>,
139
140 #[serde(default)]
147 pub listen: Option<ListenConfig>,
148}
149
/// The optional `listen` block of the config file.
///
/// Every field defaults to `None`. Each address that is set must not be
/// blank (whitespace-only); `ConfigFile::load` rejects it otherwise.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ListenConfig {
    /// TCP listen address, e.g. `"127.0.0.1:9090"`.
    #[serde(default)]
    pub tcp: Option<String>,

    /// Second TCP listen address, e.g. `"127.0.0.1:9091"`.
    #[serde(default)]
    pub tcp_v2: Option<String>,

    /// TCP listen address presumably dedicated to embeddings (per the
    /// field name) — TODO confirm against the listener code.
    #[serde(default)]
    pub tcp_embed: Option<String>,

    /// Presumably the name of an environment variable holding the API key
    /// for the TCP listeners (tests use `"INFERD_TCP_API_KEY"`). Not
    /// validated here — confirm semantics with the listener code.
    #[serde(default)]
    pub api_key_env: Option<String>,
}
189
/// One entry of the multi-backend `backends` list.
///
/// Internally tagged on `"kind"` with kebab-case variant names, so JSON
/// entries carry `"kind": "llamacpp"`, `"openai-compat"` or
/// `"bedrock-invoke"`. An unrecognized `kind` fails to deserialize, which
/// `ConfigFile::load` reports as a parse error.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "kebab-case")]
pub enum BackendEntry {
    /// `"kind": "llamacpp"` — see [`LlamacppEntry`].
    Llamacpp(LlamacppEntry),
    /// `"kind": "openai-compat"` — see [`OpenaiCompatEntry`].
    OpenaiCompat(OpenaiCompatEntry),
    /// `"kind": "bedrock-invoke"` — see [`BedrockInvokeEntry`].
    BedrockInvoke(BedrockInvokeEntry),
}
211
212impl BackendEntry {
213 pub fn name(&self) -> &str {
216 match self {
217 BackendEntry::Llamacpp(e) => &e.name,
218 BackendEntry::OpenaiCompat(e) => &e.name,
219 BackendEntry::BedrockInvoke(e) => &e.name,
220 }
221 }
222}
223
/// A `"kind": "llamacpp"` backend entry, loading the model described by
/// [`ModelConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlamacppEntry {
    /// Backend name; must be non-empty and unique within `backends`.
    pub name: String,

    /// Model to load; validated like the legacy top-level `model`
    /// (non-empty name, `https://` URL, 64 lowercase hex sha256).
    pub model: ModelConfig,

    /// Context size. Defaults to 8192; must be > 0.
    #[serde(default = "default_n_ctx")]
    pub n_ctx: u32,

    /// GPU layer count (presumably layers offloaded to the GPU). Defaults to 0.
    #[serde(default)]
    pub n_gpu_layers: i32,

    /// Marks this backend as an embedding backend. Defaults to `false`. The
    /// first-boot default ships one entry with this set.
    #[serde(default)]
    pub embed: bool,

    /// Optional pooling selector for embeddings; omitted from serialized
    /// output when `None`.
    // NOTE(review): the meaning of each integer value is defined by the
    // consumer of this field — TODO document the mapping.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub embed_pooling: Option<i32>,

    /// Context size for embedding use. Defaults to 2048 (not validated here).
    #[serde(default = "default_embed_n_ctx")]
    pub embed_n_ctx: u32,
}
268
/// A `"kind": "openai-compat"` backend entry pointing at an HTTP endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenaiCompatEntry {
    /// Backend name; must be non-empty and unique within `backends`.
    pub name: String,

    /// Endpoint base URL. Must be non-blank and start with `http://` or
    /// `https://`; plain HTTP is accepted (e.g. `http://localhost:11434`).
    pub base_url: String,

    /// Remote model identifier sent to the endpoint; must not be blank.
    pub model: String,

    /// Presumably the name of an environment variable holding the API key
    /// (tests use `"ANTHROPIC_API_KEY"`). Not validated here.
    #[serde(default)]
    pub api_key_env: Option<String>,

    /// Timeout in seconds. Defaults to 300; must be > 0.
    #[serde(default = "default_openai_timeout_secs")]
    pub timeout_secs: u64,
}
299
/// A `"kind": "bedrock-invoke"` backend entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BedrockInvokeEntry {
    /// Backend name; must be non-empty and unique within `backends`.
    pub name: String,

    /// Region, e.g. `"us-east-1"`; must not be blank.
    pub region: String,

    /// Model identifier, e.g. `"anthropic.claude-3-5-sonnet-20241022-v2:0"`;
    /// must not be blank.
    pub model_id: String,

    /// Presumably the name of an environment variable holding a bearer
    /// token (tests use `"AWS_BEARER_TOKEN_BEDROCK"`). Not validated here.
    #[serde(default)]
    pub bearer_token_env: Option<String>,

    /// Optional endpoint override (not validated here) — TODO confirm how
    /// the backend uses it when set.
    #[serde(default)]
    pub endpoint: Option<String>,

    /// Timeout in seconds. Defaults to 300; must be > 0.
    #[serde(default = "default_bedrock_timeout_secs")]
    pub timeout_secs: u64,
}
347
/// A downloadable model: where it comes from and how to verify it.
///
/// Converts into `crate::fetch::ModelSpec` via `From<&ModelConfig>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Model name; must not be empty.
    pub name: String,
    /// Expected SHA-256 digest: exactly 64 lowercase hex characters.
    pub sha256: String,
    /// Optional expected size in bytes.
    #[serde(default)]
    pub size_bytes: Option<u64>,
    /// Download URL; must start with `https://`.
    pub source_url: String,
    /// Optional license identifier, e.g. `"gemma"` or `"apache-2.0"`.
    #[serde(default)]
    pub license: Option<String>,
}
373
/// Serde default for `ConfigFile::auto_pull`: enabled unless the config
/// explicitly turns it off.
fn default_auto_pull() -> bool {
    const AUTO_PULL_ENABLED: bool = true;
    AUTO_PULL_ENABLED
}
377
/// Serde default for generation context size: 8 Ki (8192).
fn default_n_ctx() -> u32 {
    8 * 1024
}
381
/// Serde default for embedding context size: 2 Ki (2048).
fn default_embed_n_ctx() -> u32 {
    2 * 1024
}
385
/// Serde default for `OpenaiCompatEntry::timeout_secs`: five minutes.
fn default_openai_timeout_secs() -> u64 {
    5 * 60
}
389
/// Serde default for `BedrockInvokeEntry::timeout_secs`: five minutes.
fn default_bedrock_timeout_secs() -> u64 {
    5 * 60
}
393
/// Returns the user's home directory as reported by the environment.
///
/// Reads `HOME` on Unix targets and `USERPROFILE` everywhere else; returns
/// `None` when that variable is unset.
fn home_dir() -> Option<PathBuf> {
    let var_name = if cfg!(unix) { "HOME" } else { "USERPROFILE" };
    std::env::var_os(var_name).map(PathBuf::from)
}
404
405pub fn default_config_path() -> PathBuf {
409 if let Ok(p) = std::env::var("INFERD_CONFIG") {
410 return PathBuf::from(p);
411 }
412 let home = home_dir().unwrap_or_else(|| PathBuf::from("."));
413 home.join(".inferd").join("config.json")
414}
415
416pub fn default_first_boot_config() -> ConfigFile {
430 ConfigFile {
431 auto_pull: true,
432 models_home: None,
433 model: None,
434 n_ctx: default_n_ctx(),
435 n_gpu_layers: 0,
436 admin_addr: None,
437 backends: Some(vec![
438 BackendEntry::Llamacpp(LlamacppEntry {
439 name: "gemma-4-e4b".into(),
440 model: ModelConfig {
441 name: "gemma-4-e4b".into(),
442 sha256: "30d1e7949597a3446726064e80b876fd1b5cba4aa6eec53d27afa420e731fb36"
443 .into(),
444 size_bytes: Some(5_126_304_928),
445 source_url: "https://huggingface.co/unsloth/gemma-4-E4B-it-GGUF/resolve/main/\
446 gemma-4-E4B-it-UD-Q4_K_XL.gguf"
447 .into(),
448 license: Some("gemma".into()),
449 },
450 n_ctx: default_n_ctx(),
451 n_gpu_layers: 0,
452 embed: false,
453 embed_pooling: None,
454 embed_n_ctx: default_embed_n_ctx(),
455 }),
456 BackendEntry::Llamacpp(LlamacppEntry {
457 name: "embeddinggemma-300m".into(),
458 model: ModelConfig {
459 name: "embeddinggemma-300m".into(),
460 sha256: "a0f7b4e13c397a6e1b32c2de75b1f65a14c92ec524d5f674d94a4290a1c4969b"
461 .into(),
462 size_bytes: Some(328_577_056),
463 source_url:
464 "https://huggingface.co/unsloth/embeddinggemma-300m-GGUF/resolve/main/\
465 embeddinggemma-300M-Q8_0.gguf"
466 .into(),
467 license: Some("gemma".into()),
468 },
469 n_ctx: default_embed_n_ctx(),
470 n_gpu_layers: 0,
471 embed: true,
472 embed_pooling: None,
473 embed_n_ctx: default_embed_n_ctx(),
474 }),
475 ]),
476 listen: None,
477 }
478}
479
480pub fn write_default_if_missing(path: &Path) -> io::Result<bool> {
489 if path.exists() {
490 return Ok(false);
491 }
492 if let Some(parent) = path.parent() {
493 std::fs::create_dir_all(parent)?;
494 }
495 let cfg = default_first_boot_config();
496 let tmp = path.with_extension("json.tmp");
497 {
498 let file = File::create(&tmp)?;
499 let mut writer = std::io::BufWriter::new(file);
500 serde_json::to_writer_pretty(&mut writer, &cfg)
501 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
502 std::io::Write::write_all(&mut writer, b"\n")?;
503 std::io::Write::flush(&mut writer)?;
504 }
505 std::fs::rename(&tmp, path)?;
506 Ok(true)
507}
508
/// Errors returned by [`ConfigFile::load`].
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// No file exists at the given path.
    #[error("config file not found: {0}")]
    NotFound(PathBuf),
    /// Opening the file failed for a reason other than "not found".
    #[error("io reading {path}: {source}")]
    Io {
        /// The config path that was being opened.
        path: PathBuf,
        /// The underlying I/O error.
        // NOTE(review): `source` is both interpolated into the message and
        // returned from `Error::source()`, so reporters that walk the error
        // chain print it twice. Consider dropping `: {source}` from the text.
        #[source]
        source: io::Error,
    },
    /// The file contents could not be deserialized (malformed JSON, wrong
    /// field types, unknown backend `kind`, ...).
    #[error("parse {path}: {source}")]
    Parse {
        /// The config path that was being parsed.
        path: PathBuf,
        /// The underlying serde_json error.
        #[source]
        source: serde_json::Error,
    },
    /// The config parsed but violates a validation rule; the message names
    /// the offending field.
    #[error("invalid config: {0}")]
    Invalid(String),
}
537
538impl ConfigFile {
539 pub fn load(path: &Path) -> Result<Self, ConfigError> {
541 let file = File::open(path).map_err(|e| {
542 if e.kind() == io::ErrorKind::NotFound {
543 ConfigError::NotFound(path.to_path_buf())
544 } else {
545 ConfigError::Io {
546 path: path.to_path_buf(),
547 source: e,
548 }
549 }
550 })?;
551 let reader = BufReader::new(file);
552 let mut cfg: ConfigFile =
553 serde_json::from_reader(reader).map_err(|e| ConfigError::Parse {
554 path: path.to_path_buf(),
555 source: e,
556 })?;
557 cfg.expand_paths();
558 cfg.validate()?;
559 Ok(cfg)
560 }
561
562 fn expand_paths(&mut self) {
563 if let Some(p) = self.models_home.as_ref()
564 && let Some(stripped) = p
565 .to_str()
566 .and_then(|s| s.strip_prefix("~/").or_else(|| s.strip_prefix("~\\")))
567 && let Some(home) = home_dir()
568 {
569 self.models_home = Some(home.join(stripped));
570 }
571 }
572
573 fn validate(&self) -> Result<(), ConfigError> {
574 match (&self.model, &self.backends) {
575 (Some(_), Some(_)) => {
576 return Err(ConfigError::Invalid(
577 "config: `model` and `backends` are mutually exclusive — \
578 pick one shape, not both"
579 .into(),
580 ));
581 }
582 (None, None) => {
583 return Err(ConfigError::Invalid(
584 "config: must specify either `model` (legacy single-backend) \
585 or `backends` (multi-backend list)"
586 .into(),
587 ));
588 }
589 _ => {}
590 }
591 if self.n_ctx == 0 {
592 return Err(ConfigError::Invalid("n_ctx must be > 0".into()));
593 }
594 if let Some(m) = &self.model {
595 validate_model_config(m)?;
596 }
597 if let Some(listen) = &self.listen {
598 if let Some(addr) = &listen.tcp
599 && addr.trim().is_empty()
600 {
601 return Err(ConfigError::Invalid(
602 "listen.tcp must not be empty when set".into(),
603 ));
604 }
605 if let Some(addr) = &listen.tcp_v2
606 && addr.trim().is_empty()
607 {
608 return Err(ConfigError::Invalid(
609 "listen.tcp_v2 must not be empty when set".into(),
610 ));
611 }
612 if let Some(addr) = &listen.tcp_embed
613 && addr.trim().is_empty()
614 {
615 return Err(ConfigError::Invalid(
616 "listen.tcp_embed must not be empty when set".into(),
617 ));
618 }
619 }
620 if let Some(list) = &self.backends {
621 if list.is_empty() {
622 return Err(ConfigError::Invalid(
623 "backends list must not be empty".into(),
624 ));
625 }
626 let mut seen = std::collections::HashSet::with_capacity(list.len());
627 for entry in list {
628 let name = entry.name();
629 if name.is_empty() {
630 return Err(ConfigError::Invalid(
631 "backends[].name must not be empty".into(),
632 ));
633 }
634 if !seen.insert(name.to_string()) {
635 return Err(ConfigError::Invalid(format!(
636 "duplicate backends[].name {name:?} — names must be unique"
637 )));
638 }
639 match entry {
640 BackendEntry::Llamacpp(e) => {
641 validate_model_config(&e.model)?;
642 if e.n_ctx == 0 {
643 return Err(ConfigError::Invalid(format!(
644 "backends[{name:?}].n_ctx must be > 0"
645 )));
646 }
647 }
648 BackendEntry::OpenaiCompat(e) => {
649 if e.base_url.trim().is_empty() {
650 return Err(ConfigError::Invalid(format!(
651 "backends[{name:?}].base_url must not be empty"
652 )));
653 }
654 if !(e.base_url.starts_with("https://")
655 || e.base_url.starts_with("http://"))
656 {
657 return Err(ConfigError::Invalid(format!(
658 "backends[{name:?}].base_url must be http:// or https:// \
659 (got {:?})",
660 e.base_url
661 )));
662 }
663 if e.model.trim().is_empty() {
664 return Err(ConfigError::Invalid(format!(
665 "backends[{name:?}].model must not be empty"
666 )));
667 }
668 if e.timeout_secs == 0 {
669 return Err(ConfigError::Invalid(format!(
670 "backends[{name:?}].timeout_secs must be > 0"
671 )));
672 }
673 }
674 BackendEntry::BedrockInvoke(e) => {
675 if e.region.trim().is_empty() {
676 return Err(ConfigError::Invalid(format!(
677 "backends[{name:?}].region must not be empty"
678 )));
679 }
680 if e.model_id.trim().is_empty() {
681 return Err(ConfigError::Invalid(format!(
682 "backends[{name:?}].model_id must not be empty"
683 )));
684 }
685 if e.timeout_secs == 0 {
686 return Err(ConfigError::Invalid(format!(
687 "backends[{name:?}].timeout_secs must be > 0"
688 )));
689 }
690 }
691 }
692 }
693 }
694 Ok(())
695 }
696
697 pub fn resolved_backends(&self) -> Vec<BackendEntry> {
702 if let Some(list) = &self.backends {
703 return list.clone();
704 }
705 let m = self
709 .model
710 .as_ref()
711 .expect("validate() guarantees one of model|backends is set")
712 .clone();
713 vec![BackendEntry::Llamacpp(LlamacppEntry {
714 name: m.name.clone(),
715 model: m,
716 n_ctx: self.n_ctx,
717 n_gpu_layers: self.n_gpu_layers,
718 embed: false,
723 embed_pooling: None,
724 embed_n_ctx: default_embed_n_ctx(),
725 })]
726 }
727}
728
729fn validate_model_config(m: &ModelConfig) -> Result<(), ConfigError> {
730 if m.name.is_empty() {
731 return Err(ConfigError::Invalid("model.name must not be empty".into()));
732 }
733 if !m.source_url.starts_with("https://") {
734 return Err(ConfigError::Invalid(format!(
735 "model.source_url must be https:// (got {:?})",
736 m.source_url
737 )));
738 }
739 if m.sha256.len() != 64
740 || !m
741 .sha256
742 .bytes()
743 .all(|b| b.is_ascii_hexdigit() && !b.is_ascii_uppercase())
744 {
745 return Err(ConfigError::Invalid(
746 "model.sha256 must be 64 lowercase hex chars".into(),
747 ));
748 }
749 Ok(())
750}
751
752impl From<&ModelConfig> for crate::fetch::ModelSpec {
753 fn from(m: &ModelConfig) -> Self {
754 crate::fetch::ModelSpec {
755 name: m.name.clone(),
756 source_url: m.source_url.clone(),
757 sha256_hex: m.sha256.clone(),
758 size_bytes: m.size_bytes,
759 license: m.license.clone(),
760 source: None,
761 }
762 }
763}
764
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Writes `s` to a fresh temp file that lives as long as the returned handle.
    fn write_config(s: &str) -> tempfile::NamedTempFile {
        let mut f = tempfile::NamedTempFile::new().unwrap();
        f.write_all(s.as_bytes()).unwrap();
        f.flush().unwrap();
        f
    }

    /// A valid legacy single-model config with every optional field set.
    fn good_json() -> String {
        r#"{
            "auto_pull": true,
            "models_home": "/tmp/inferd-models-home",
            "model": {
                "name": "gemma-4-e4b",
                "sha256": "30d1e7949597a3446726064e80b876fd1b5cba4aa6eec53d27afa420e731fb36",
                "size_bytes": 5126304928,
                "source_url": "https://huggingface.co/unsloth/gemma-4-E4B-it-GGUF/resolve/main/gemma-4-E4B-it-UD-Q4_K_XL.gguf",
                "license": "apache-2.0"
            },
            "n_ctx": 8192,
            "n_gpu_layers": 0
        }"#
        .to_string()
    }

    #[test]
    fn load_well_formed_config() {
        let f = write_config(&good_json());
        let cfg = ConfigFile::load(f.path()).unwrap();
        let m = cfg.model.as_ref().expect("legacy model present");
        assert_eq!(m.name, "gemma-4-e4b");
        assert_eq!(m.size_bytes, Some(5_126_304_928));
        assert_eq!(m.license.as_deref(), Some("apache-2.0"));
        assert!(cfg.auto_pull);
        assert_eq!(cfg.n_ctx, 8192);
        assert_eq!(
            cfg.models_home,
            Some(PathBuf::from("/tmp/inferd-models-home"))
        );
    }

    #[test]
    fn missing_file_returns_not_found() {
        // A private tempdir guarantees absence without deleting or racing on
        // a fixed file name in the shared global temp dir.
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("inferd-config-does-not-exist.json");
        let err = ConfigFile::load(&path).unwrap_err();
        assert!(matches!(err, ConfigError::NotFound(_)));
    }

    #[test]
    fn invalid_json_returns_parse_error() {
        let f = write_config("{ not valid json");
        let err = ConfigFile::load(f.path()).unwrap_err();
        assert!(matches!(err, ConfigError::Parse { .. }));
    }

    #[test]
    fn http_url_rejected() {
        let bad = good_json().replace("https://", "http://");
        let f = write_config(&bad);
        let err = ConfigFile::load(f.path()).unwrap_err();
        match err {
            ConfigError::Invalid(msg) => assert!(msg.contains("https://")),
            other => panic!("expected Invalid, got {other:?}"),
        }
    }

    #[test]
    fn uppercase_sha_rejected() {
        let bad = good_json().replace(
            "30d1e7949597a3446726064e80b876fd1b5cba4aa6eec53d27afa420e731fb36",
            "30D1E7949597A3446726064E80B876FD1B5CBA4AA6EEC53D27AFA420E731FB36",
        );
        let f = write_config(&bad);
        let err = ConfigFile::load(f.path()).unwrap_err();
        match err {
            ConfigError::Invalid(msg) => assert!(msg.contains("lowercase hex")),
            other => panic!("expected Invalid, got {other:?}"),
        }
    }

    #[test]
    fn short_sha_rejected() {
        let bad = good_json().replace(
            "30d1e7949597a3446726064e80b876fd1b5cba4aa6eec53d27afa420e731fb36",
            "30d1e7",
        );
        let f = write_config(&bad);
        let err = ConfigFile::load(f.path()).unwrap_err();
        assert!(matches!(err, ConfigError::Invalid(_)));
    }

    #[test]
    fn defaults_when_optional_fields_missing() {
        let json = r#"{
            "model": {
                "name": "gemma-4-e4b",
                "sha256": "30d1e7949597a3446726064e80b876fd1b5cba4aa6eec53d27afa420e731fb36",
                "source_url": "https://example.com/x.gguf"
            }
        }"#;
        let f = write_config(json);
        let cfg = ConfigFile::load(f.path()).unwrap();
        let m = cfg.model.as_ref().expect("legacy model present");
        assert!(cfg.auto_pull);
        assert_eq!(cfg.n_ctx, 8192);
        assert_eq!(cfg.n_gpu_layers, 0);
        assert!(m.size_bytes.is_none());
        assert!(cfg.models_home.is_none());
        assert!(m.license.is_none());
    }

    #[test]
    fn models_home_tilde_is_expanded() {
        let json = r#"{
            "models_home": "~/inferd-models",
            "model": {
                "name": "gemma-4-e4b",
                "sha256": "30d1e7949597a3446726064e80b876fd1b5cba4aa6eec53d27afa420e731fb36",
                "source_url": "https://example.com/x.gguf"
            }
        }"#;
        let f = write_config(json);
        let cfg = ConfigFile::load(f.path()).unwrap();
        // Without a known home directory the path is left untouched.
        let expected = match home_dir() {
            Some(home) => home.join("inferd-models"),
            None => PathBuf::from("~/inferd-models"),
        };
        assert_eq!(cfg.models_home, Some(expected));
    }

    #[test]
    fn modelconfig_converts_to_fetch_modelspec() {
        let cfg = ModelConfig {
            name: "x".into(),
            sha256: "abc".into(),
            size_bytes: Some(42),
            source_url: "https://e/x.gguf".into(),
            license: Some("mit".into()),
        };
        let spec: crate::fetch::ModelSpec = (&cfg).into();
        assert_eq!(spec.name, "x");
        assert_eq!(spec.size_bytes, Some(42));
        assert_eq!(spec.sha256_hex, "abc");
        assert_eq!(spec.license.as_deref(), Some("mit"));
    }

    /// A valid multi-backend config: one llama.cpp and one openai-compat entry.
    fn good_multi_backend_json() -> String {
        r#"{
            "models_home": "/tmp/inferd-models-home",
            "backends": [
                {
                    "kind": "llamacpp",
                    "name": "local-gemma",
                    "model": {
                        "name": "gemma-4-e4b",
                        "sha256": "30d1e7949597a3446726064e80b876fd1b5cba4aa6eec53d27afa420e731fb36",
                        "source_url": "https://example.com/gemma.gguf"
                    },
                    "n_ctx": 8192,
                    "n_gpu_layers": 35
                },
                {
                    "kind": "openai-compat",
                    "name": "anthropic-fallback",
                    "base_url": "https://api.anthropic.com",
                    "model": "claude-opus-4-7",
                    "api_key_env": "ANTHROPIC_API_KEY"
                }
            ]
        }"#
        .to_string()
    }

    #[test]
    fn load_multi_backend_config() {
        let f = write_config(&good_multi_backend_json());
        let cfg = ConfigFile::load(f.path()).unwrap();
        assert!(cfg.model.is_none());
        let list = cfg.backends.as_ref().expect("backends present");
        assert_eq!(list.len(), 2);
        match &list[0] {
            BackendEntry::Llamacpp(e) => {
                assert_eq!(e.name, "local-gemma");
                assert_eq!(e.model.name, "gemma-4-e4b");
                assert_eq!(e.n_ctx, 8192);
                assert_eq!(e.n_gpu_layers, 35);
            }
            other => panic!("expected llamacpp, got {other:?}"),
        }
        match &list[1] {
            BackendEntry::OpenaiCompat(e) => {
                assert_eq!(e.name, "anthropic-fallback");
                assert_eq!(e.base_url, "https://api.anthropic.com");
                assert_eq!(e.model, "claude-opus-4-7");
                assert_eq!(e.api_key_env.as_deref(), Some("ANTHROPIC_API_KEY"));
                assert_eq!(e.timeout_secs, 300);
            }
            other => panic!("expected openai-compat, got {other:?}"),
        }
    }

    #[test]
    fn rejects_both_model_and_backends() {
        let json = r#"{
            "model": {
                "name": "gemma-4-e4b",
                "sha256": "30d1e7949597a3446726064e80b876fd1b5cba4aa6eec53d27afa420e731fb36",
                "source_url": "https://example.com/x.gguf"
            },
            "backends": [
                {
                    "kind": "openai-compat",
                    "name": "x",
                    "base_url": "https://api.openai.com",
                    "model": "gpt-4o-mini"
                }
            ]
        }"#;
        let f = write_config(json);
        let err = ConfigFile::load(f.path()).unwrap_err();
        match err {
            ConfigError::Invalid(msg) => assert!(msg.contains("mutually exclusive")),
            other => panic!("expected Invalid, got {other:?}"),
        }
    }

    #[test]
    fn rejects_neither_model_nor_backends() {
        let json = r#"{ "auto_pull": true }"#;
        let f = write_config(json);
        let err = ConfigFile::load(f.path()).unwrap_err();
        match err {
            ConfigError::Invalid(msg) => assert!(msg.contains("must specify either")),
            other => panic!("expected Invalid, got {other:?}"),
        }
    }

    #[test]
    fn rejects_empty_backends_list() {
        let json = r#"{ "backends": [] }"#;
        let f = write_config(json);
        let err = ConfigFile::load(f.path()).unwrap_err();
        match err {
            ConfigError::Invalid(msg) => assert!(msg.contains("must not be empty")),
            other => panic!("expected Invalid, got {other:?}"),
        }
    }

    #[test]
    fn rejects_duplicate_backend_names() {
        let json = r#"{
            "backends": [
                {
                    "kind": "openai-compat",
                    "name": "dup",
                    "base_url": "https://api.openai.com",
                    "model": "gpt-4o-mini"
                },
                {
                    "kind": "openai-compat",
                    "name": "dup",
                    "base_url": "https://api.anthropic.com",
                    "model": "claude-opus-4-7"
                }
            ]
        }"#;
        let f = write_config(json);
        let err = ConfigFile::load(f.path()).unwrap_err();
        match err {
            ConfigError::Invalid(msg) => assert!(msg.contains("duplicate")),
            other => panic!("expected Invalid, got {other:?}"),
        }
    }

    #[test]
    fn rejects_openai_compat_without_base_url() {
        let json = r#"{
            "backends": [
                {
                    "kind": "openai-compat",
                    "name": "x",
                    "base_url": "",
                    "model": "gpt-4o-mini"
                }
            ]
        }"#;
        let f = write_config(json);
        let err = ConfigFile::load(f.path()).unwrap_err();
        assert!(matches!(err, ConfigError::Invalid(_)));
    }

    #[test]
    fn rejects_openai_compat_with_bad_scheme() {
        let json = r#"{
            "backends": [
                {
                    "kind": "openai-compat",
                    "name": "x",
                    "base_url": "ftp://api.openai.com",
                    "model": "gpt-4o-mini"
                }
            ]
        }"#;
        let f = write_config(json);
        let err = ConfigFile::load(f.path()).unwrap_err();
        match err {
            ConfigError::Invalid(msg) => assert!(msg.contains("http")),
            other => panic!("expected Invalid, got {other:?}"),
        }
    }

    #[test]
    fn accepts_openai_compat_with_localhost_http() {
        let json = r#"{
            "backends": [
                {
                    "kind": "openai-compat",
                    "name": "ollama",
                    "base_url": "http://localhost:11434",
                    "model": "llama3.1:8b"
                }
            ]
        }"#;
        let f = write_config(json);
        let cfg = ConfigFile::load(f.path()).unwrap();
        assert_eq!(cfg.resolved_backends().len(), 1);
    }

    #[test]
    fn rejects_unknown_kind() {
        let json = r#"{
            "backends": [
                {
                    "kind": "future-thing-not-supported",
                    "name": "x"
                }
            ]
        }"#;
        let f = write_config(json);
        let err = ConfigFile::load(f.path()).unwrap_err();
        assert!(matches!(err, ConfigError::Parse { .. }));
    }

    #[test]
    fn loads_bedrock_invoke_entry() {
        let json = r#"{
            "backends": [
                {
                    "kind": "bedrock-invoke",
                    "name": "bedrock-claude",
                    "region": "us-east-1",
                    "model_id": "anthropic.claude-3-5-sonnet-20241022-v2:0",
                    "bearer_token_env": "AWS_BEARER_TOKEN_BEDROCK"
                }
            ]
        }"#;
        let f = write_config(json);
        let cfg = ConfigFile::load(f.path()).unwrap();
        let list = cfg.backends.as_ref().unwrap();
        assert_eq!(list.len(), 1);
        match &list[0] {
            BackendEntry::BedrockInvoke(e) => {
                assert_eq!(e.name, "bedrock-claude");
                assert_eq!(e.region, "us-east-1");
                assert_eq!(e.model_id, "anthropic.claude-3-5-sonnet-20241022-v2:0");
                assert_eq!(
                    e.bearer_token_env.as_deref(),
                    Some("AWS_BEARER_TOKEN_BEDROCK")
                );
                assert!(e.endpoint.is_none());
                assert_eq!(e.timeout_secs, 300);
            }
            other => panic!("expected bedrock-invoke, got {other:?}"),
        }
    }

    #[test]
    fn rejects_bedrock_invoke_without_region() {
        let json = r#"{
            "backends": [
                {
                    "kind": "bedrock-invoke",
                    "name": "x",
                    "region": "",
                    "model_id": "anthropic.claude-3-5-sonnet-20241022-v2:0"
                }
            ]
        }"#;
        let f = write_config(json);
        let err = ConfigFile::load(f.path()).unwrap_err();
        match err {
            ConfigError::Invalid(msg) => assert!(msg.contains("region")),
            other => panic!("expected Invalid, got {other:?}"),
        }
    }

    #[test]
    fn rejects_bedrock_invoke_without_model_id() {
        let json = r#"{
            "backends": [
                {
                    "kind": "bedrock-invoke",
                    "name": "x",
                    "region": "us-east-1",
                    "model_id": ""
                }
            ]
        }"#;
        let f = write_config(json);
        let err = ConfigFile::load(f.path()).unwrap_err();
        match err {
            ConfigError::Invalid(msg) => assert!(msg.contains("model_id")),
            other => panic!("expected Invalid, got {other:?}"),
        }
    }

    #[test]
    fn legacy_model_promotes_to_one_backend() {
        let f = write_config(&good_json());
        let cfg = ConfigFile::load(f.path()).unwrap();
        let resolved = cfg.resolved_backends();
        assert_eq!(resolved.len(), 1);
        match &resolved[0] {
            BackendEntry::Llamacpp(e) => {
                assert_eq!(e.name, "gemma-4-e4b");
                assert_eq!(e.n_ctx, 8192);
                assert_eq!(e.n_gpu_layers, 0);
            }
            other => panic!("expected llamacpp, got {other:?}"),
        }
    }

    #[test]
    fn multi_backend_resolved_passes_through() {
        let f = write_config(&good_multi_backend_json());
        let cfg = ConfigFile::load(f.path()).unwrap();
        let resolved = cfg.resolved_backends();
        assert_eq!(resolved.len(), 2);
        assert_eq!(resolved[0].name(), "local-gemma");
        assert_eq!(resolved[1].name(), "anthropic-fallback");
    }

    #[test]
    fn listen_block_absent_by_default() {
        let f = write_config(&good_json());
        let cfg = ConfigFile::load(f.path()).unwrap();
        assert!(cfg.listen.is_none());
    }

    #[test]
    fn listen_block_carries_tcp_and_api_key_env() {
        let json = r#"{
            "model": {
                "name": "gemma-4-e4b",
                "sha256": "30d1e7949597a3446726064e80b876fd1b5cba4aa6eec53d27afa420e731fb36",
                "source_url": "https://example.com/x.gguf"
            },
            "listen": {
                "tcp": "127.0.0.1:9090",
                "tcp_v2": "127.0.0.1:9091",
                "api_key_env": "INFERD_TCP_API_KEY"
            }
        }"#;
        let f = write_config(json);
        let cfg = ConfigFile::load(f.path()).unwrap();
        let listen = cfg.listen.as_ref().expect("listen present");
        assert_eq!(listen.tcp.as_deref(), Some("127.0.0.1:9090"));
        assert_eq!(listen.tcp_v2.as_deref(), Some("127.0.0.1:9091"));
        assert_eq!(listen.api_key_env.as_deref(), Some("INFERD_TCP_API_KEY"));
    }

    #[test]
    fn llamacpp_entry_embed_defaults_off() {
        let f = write_config(&good_multi_backend_json());
        let cfg = ConfigFile::load(f.path()).unwrap();
        let list = cfg.backends.as_ref().unwrap();
        match &list[0] {
            BackendEntry::Llamacpp(e) => {
                assert!(!e.embed);
                assert!(e.embed_pooling.is_none());
                assert_eq!(e.embed_n_ctx, 2048);
            }
            other => panic!("expected llamacpp, got {other:?}"),
        }
    }

    #[test]
    fn llamacpp_entry_carries_embed_fields() {
        let json = r#"{
            "backends": [
                {
                    "kind": "llamacpp",
                    "name": "embeddings",
                    "embed": true,
                    "embed_pooling": 1,
                    "embed_n_ctx": 1024,
                    "model": {
                        "name": "embeddinggemma-300m",
                        "sha256": "30d1e7949597a3446726064e80b876fd1b5cba4aa6eec53d27afa420e731fb36",
                        "source_url": "https://example.com/embed.gguf"
                    }
                }
            ]
        }"#;
        let f = write_config(json);
        let cfg = ConfigFile::load(f.path()).unwrap();
        let list = cfg.backends.as_ref().unwrap();
        match &list[0] {
            BackendEntry::Llamacpp(e) => {
                assert!(e.embed);
                assert_eq!(e.embed_pooling, Some(1));
                assert_eq!(e.embed_n_ctx, 1024);
            }
            other => panic!("expected llamacpp, got {other:?}"),
        }
    }

    #[test]
    fn legacy_promotion_keeps_embed_off() {
        let f = write_config(&good_json());
        let cfg = ConfigFile::load(f.path()).unwrap();
        let list = cfg.resolved_backends();
        match &list[0] {
            BackendEntry::Llamacpp(e) => {
                assert!(!e.embed);
                assert!(e.embed_pooling.is_none());
                assert_eq!(e.embed_n_ctx, 2048);
            }
            other => panic!("expected llamacpp, got {other:?}"),
        }
    }

    #[test]
    fn listen_rejects_empty_tcp() {
        let json = r#"{
            "model": {
                "name": "gemma-4-e4b",
                "sha256": "30d1e7949597a3446726064e80b876fd1b5cba4aa6eec53d27afa420e731fb36",
                "source_url": "https://example.com/x.gguf"
            },
            "listen": { "tcp": "   " }
        }"#;
        let f = write_config(json);
        let err = ConfigFile::load(f.path()).unwrap_err();
        match err {
            ConfigError::Invalid(msg) => assert!(msg.contains("listen.tcp")),
            other => panic!("expected Invalid, got {other:?}"),
        }
    }

    #[test]
    fn listen_rejects_empty_tcp_v2_and_tcp_embed() {
        for field in ["tcp_v2", "tcp_embed"] {
            let json = serde_json::json!({
                "model": {
                    "name": "gemma-4-e4b",
                    "sha256": "30d1e7949597a3446726064e80b876fd1b5cba4aa6eec53d27afa420e731fb36",
                    "source_url": "https://example.com/x.gguf"
                },
                "listen": { field: "" }
            });
            let f = write_config(&json.to_string());
            let err = ConfigFile::load(f.path()).unwrap_err();
            match err {
                ConfigError::Invalid(msg) => assert!(
                    msg.contains(&format!("listen.{field}")),
                    "unexpected message for {field}: {msg}"
                ),
                other => panic!("expected Invalid for {field}, got {other:?}"),
            }
        }
    }

    #[test]
    fn default_first_boot_config_has_generate_and_embed() {
        let cfg = default_first_boot_config();
        assert!(cfg.auto_pull, "first-boot default must auto-pull");
        let list = cfg.backends.as_ref().expect("backends present");
        assert_eq!(list.len(), 2, "default ships generate + embed");

        let mut saw_generate = false;
        let mut saw_embed = false;
        for entry in list {
            if let BackendEntry::Llamacpp(e) = entry {
                if e.embed {
                    saw_embed = true;
                    assert_eq!(e.model.name, "embeddinggemma-300m");
                } else {
                    saw_generate = true;
                    assert_eq!(e.model.name, "gemma-4-e4b");
                }
            }
        }
        assert!(saw_generate, "default must include a generate backend");
        assert!(saw_embed, "default must include an embed backend");
    }

    #[test]
    fn default_first_boot_config_validates() {
        let cfg = default_first_boot_config();
        // Round-trip through JSON so the full load path (parse + expand +
        // validate) runs on exactly what first boot would write.
        let json = serde_json::to_string(&cfg).unwrap();
        let f = write_config(&json);
        ConfigFile::load(f.path()).expect("default config validates");
    }

    #[test]
    fn write_default_if_missing_writes_when_absent() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("config.json");
        assert!(!path.exists());

        let wrote = write_default_if_missing(&path).unwrap();
        assert!(wrote, "should report wrote=true on first call");
        assert!(path.exists(), "default config now on disk");
        assert!(
            !path.with_extension("json.tmp").exists(),
            "temp file must not be left behind"
        );

        let cfg = ConfigFile::load(&path).unwrap();
        assert!(cfg.backends.is_some());
        assert!(cfg.auto_pull);
    }

    #[test]
    fn write_default_if_missing_does_not_overwrite() {
        let dir = tempfile::tempdir().unwrap();
        // Deliberately not a valid config: the writer must not even look at
        // the contents, only at whether the file exists.
        let path = dir.path().join("config.json");
        std::fs::write(&path, "{ \"i_am_user_data\": true }").unwrap();

        let wrote = write_default_if_missing(&path).unwrap();
        assert!(!wrote, "should report wrote=false when file exists");

        let on_disk = std::fs::read_to_string(&path).unwrap();
        assert!(
            on_disk.contains("i_am_user_data"),
            "operator file preserved verbatim"
        );
    }

    #[test]
    fn write_default_if_missing_creates_parent_dir() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("nested").join("subdir").join("config.json");
        assert!(!path.parent().unwrap().exists());

        let wrote = write_default_if_missing(&path).unwrap();
        assert!(wrote);
        assert!(path.exists());
    }
}
1379}