1use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine};
20use llmtrace_core::{SecurityFinding, SecuritySeverity};
21use rand::Rng;
22use std::collections::HashMap;
23use std::sync::{Arc, RwLock};
24use std::time::Instant;
25
/// A canary token: a marker string that is placed into a prompt so that its
/// later appearance in a response can be detected.
#[derive(Debug, Clone)]
pub struct CanaryToken {
    /// Full token text. Tokens built by [`CanaryToken::generate`] consist of
    /// the configured prefix followed by random alphanumeric characters.
    pub token: String,
    /// Monotonic timestamp captured when the token was created.
    pub created_at: Instant,
    /// Optional caller-supplied label; `None` unless one was provided.
    pub label: Option<String>,
}
43
/// Settings that control canary token generation and detection.
#[derive(Debug, Clone)]
pub struct CanaryConfig {
    /// Master switch: when `false`, `detect_canary` reports nothing.
    pub enabled: bool,
    /// Text placed in front of the random part of every generated token.
    pub prefix: String,
    /// Number of random alphanumeric characters generated after the prefix.
    pub token_length: usize,
    /// Whether detection falls back to looking for fragments of a token when
    /// no whole-token form is found.
    pub detect_partial: bool,
    /// Smallest fragment length (in bytes of the token) that counts as a
    /// partial match.
    pub partial_min_length: usize,
}
58
59impl Default for CanaryConfig {
60 fn default() -> Self {
61 Self {
62 enabled: true,
63 prefix: "CANARY-".to_string(),
64 token_length: 16,
65 detect_partial: true,
66 partial_min_length: 8,
67 }
68 }
69}
70
71#[derive(Debug, Clone)]
73pub struct CanaryDetection {
74 pub token: String,
76 pub match_type: CanaryMatchType,
78 pub confidence: f64,
80 pub position: usize,
82}
83
/// The way a canary token was recognised in a response.
#[derive(Debug, Clone, PartialEq)]
pub enum CanaryMatchType {
    /// The whole token text was found.
    Full,
    /// Only a fragment of the token was found.
    Partial {
        /// Length of the fragment that matched.
        matched_length: usize,
    },
    /// A transformed form of the whole token was found.
    Encoded {
        /// Name of the transformation, e.g. `"base64"`, `"hex"` or
        /// `"reversed"`.
        encoding: String,
    },
}
100
101impl CanaryToken {
106 pub fn generate(config: &CanaryConfig) -> Self {
121 let mut rng = rand::thread_rng();
122 let random_part: String = (0..config.token_length)
123 .map(|_| rng.sample(rand::distributions::Alphanumeric) as char)
124 .collect();
125
126 Self {
127 token: format!("{}{}", config.prefix, random_part),
128 created_at: Instant::now(),
129 label: None,
130 }
131 }
132
133 pub fn generate_with_label(config: &CanaryConfig, label: &str) -> Self {
148 let mut token = Self::generate(config);
149 token.label = Some(label.to_string());
150 token
151 }
152}
153
154pub fn inject_canary(prompt: &str, token: &CanaryToken) -> String {
182 format!("{}\n[SYSTEM_INTEGRITY_TOKEN: {}]\n", prompt, token.token)
183}
184
185pub fn detect_canary(
214 response: &str,
215 tokens: &[CanaryToken],
216 config: &CanaryConfig,
217) -> Vec<CanaryDetection> {
218 if !config.enabled {
219 return Vec::new();
220 }
221
222 let mut detections = Vec::new();
223 let response_lower = response.to_lowercase();
224
225 for canary in tokens {
226 let token_str = &canary.token;
227
228 if let Some(pos) = response.find(token_str) {
230 detections.push(CanaryDetection {
231 token: token_str.clone(),
232 match_type: CanaryMatchType::Full,
233 confidence: 1.0,
234 position: pos,
235 });
236 continue; }
238
239 let token_lower = token_str.to_lowercase();
241 if let Some(pos) = response_lower.find(&token_lower) {
242 detections.push(CanaryDetection {
243 token: token_str.clone(),
244 match_type: CanaryMatchType::Full,
245 confidence: 0.95,
246 position: pos,
247 });
248 continue;
249 }
250
251 let b64_encoded = BASE64_STANDARD.encode(token_str.as_bytes());
253 if let Some(pos) = response.find(&b64_encoded) {
254 detections.push(CanaryDetection {
255 token: token_str.clone(),
256 match_type: CanaryMatchType::Encoded {
257 encoding: "base64".to_string(),
258 },
259 confidence: 0.9,
260 position: pos,
261 });
262 continue;
263 }
264
265 let hex_encoded = hex_encode(token_str);
267 if let Some(pos) = response.to_lowercase().find(&hex_encoded.to_lowercase()) {
268 detections.push(CanaryDetection {
269 token: token_str.clone(),
270 match_type: CanaryMatchType::Encoded {
271 encoding: "hex".to_string(),
272 },
273 confidence: 0.85,
274 position: pos,
275 });
276 continue;
277 }
278
279 let reversed: String = token_str.chars().rev().collect();
281 if let Some(pos) = response.find(&reversed) {
282 detections.push(CanaryDetection {
283 token: token_str.clone(),
284 match_type: CanaryMatchType::Encoded {
285 encoding: "reversed".to_string(),
286 },
287 confidence: 0.85,
288 position: pos,
289 });
290 continue;
291 }
292
293 if config.detect_partial && token_str.len() >= config.partial_min_length {
295 if let Some(detection) = detect_partial_match(response, token_str, config) {
296 detections.push(detection);
297 }
298 }
299 }
300
301 detections
302}
303
304fn detect_partial_match(
306 response: &str,
307 token: &str,
308 config: &CanaryConfig,
309) -> Option<CanaryDetection> {
310 let min_len = config.partial_min_length;
312 if token.len() < min_len {
313 return None;
314 }
315
316 for window_size in (min_len..token.len()).rev() {
317 for start in 0..=(token.len() - window_size) {
318 let substr = &token[start..start + window_size];
319 if let Some(pos) = response.find(substr) {
320 let confidence = window_size as f64 / token.len() as f64;
321 return Some(CanaryDetection {
322 token: token.to_string(),
323 match_type: CanaryMatchType::Partial {
324 matched_length: window_size,
325 },
326 confidence,
327 position: pos,
328 });
329 }
330 }
331 }
332
333 None
334}
335
/// Encodes the UTF-8 bytes of `s` as lowercase hexadecimal, two digits per
/// byte, with no separators.
fn hex_encode(s: &str) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut out = String::with_capacity(s.len() * 2);
    for &byte in s.as_bytes() {
        // High nibble first, then low nibble.
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
340
341pub fn detect_canary_leakage(
369 response: &str,
370 tokens: &[CanaryToken],
371 config: &CanaryConfig,
372) -> Vec<SecurityFinding> {
373 detect_canary(response, tokens, config)
374 .into_iter()
375 .map(|detection| {
376 let severity = match &detection.match_type {
377 CanaryMatchType::Full => SecuritySeverity::Critical,
378 CanaryMatchType::Encoded { .. } => SecuritySeverity::High,
379 CanaryMatchType::Partial { .. } => SecuritySeverity::Medium,
380 };
381
382 let match_desc = match &detection.match_type {
383 CanaryMatchType::Full => "exact match".to_string(),
384 CanaryMatchType::Partial { matched_length } => {
385 format!("partial match ({matched_length} chars)")
386 }
387 CanaryMatchType::Encoded { encoding } => {
388 format!("encoded match ({encoding})")
389 }
390 };
391
392 SecurityFinding::new(
393 severity,
394 "canary_token_leakage".to_string(),
395 format!(
396 "System prompt leakage detected: canary token '{}' found via {} at position {} (confidence: {:.2})",
397 detection.token,
398 match_desc,
399 detection.position,
400 detection.confidence,
401 ),
402 detection.confidence,
403 )
404 .with_metadata("token".to_string(), detection.token)
405 .with_metadata("match_type".to_string(), format!("{:?}", detection.match_type))
406 .with_metadata("position".to_string(), detection.position.to_string())
407 .with_location("response.content".to_string())
408 })
409 .collect()
410}
411
412#[derive(Debug, Clone)]
438pub struct CanaryTokenStore {
439 inner: Arc<RwLock<HashMap<String, Vec<CanaryToken>>>>,
441}
442
443impl Default for CanaryTokenStore {
444 fn default() -> Self {
445 Self::new()
446 }
447}
448
449impl CanaryTokenStore {
450 pub fn new() -> Self {
452 Self {
453 inner: Arc::new(RwLock::new(HashMap::new())),
454 }
455 }
456
457 pub fn add(&self, tenant_id: &str, token: CanaryToken) {
459 let mut map = self.inner.write().expect("canary store lock poisoned");
460 map.entry(tenant_id.to_string()).or_default().push(token);
461 }
462
463 pub fn remove(&self, tenant_id: &str, token_str: &str) -> bool {
467 let mut map = self.inner.write().expect("canary store lock poisoned");
468 if let Some(tokens) = map.get_mut(tenant_id) {
469 let before = tokens.len();
470 tokens.retain(|t| t.token != token_str);
471 let removed = tokens.len() < before;
472 if tokens.is_empty() {
474 map.remove(tenant_id);
475 }
476 removed
477 } else {
478 false
479 }
480 }
481
482 pub fn get(&self, tenant_id: &str) -> Vec<CanaryToken> {
486 let map = self.inner.read().expect("canary store lock poisoned");
487 map.get(tenant_id).cloned().unwrap_or_default()
488 }
489
490 pub fn tenant_count(&self) -> usize {
492 let map = self.inner.read().expect("canary store lock poisoned");
493 map.len()
494 }
495
496 pub fn token_count(&self) -> usize {
498 let map = self.inner.read().expect("canary store lock poisoned");
499 map.values().map(|v| v.len()).sum()
500 }
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration used by most tests: the defaults.
    fn default_config() -> CanaryConfig {
        CanaryConfig::default()
    }

    /// Defaults with partial matching switched off, so only whole-token
    /// checks can fire.
    fn whole_token_only_config() -> CanaryConfig {
        CanaryConfig {
            detect_partial: false,
            ..default_config()
        }
    }

    /// Defaults with partial matching on and the given minimum fragment length.
    fn partial_config(partial_min_length: usize) -> CanaryConfig {
        CanaryConfig {
            detect_partial: true,
            partial_min_length,
            ..default_config()
        }
    }

    // --- Token generation ---

    #[test]
    fn test_generate_has_correct_prefix() {
        let generated = CanaryToken::generate(&default_config());
        assert!(
            generated.token.starts_with("CANARY-"),
            "token should start with default prefix"
        );
    }

    #[test]
    fn test_generate_has_correct_length() {
        let config = default_config();
        let generated = CanaryToken::generate(&config);
        assert_eq!(
            generated.token.len(),
            config.prefix.len() + config.token_length
        );
    }

    #[test]
    fn test_generate_tokens_are_unique() {
        let config = default_config();
        let first = CanaryToken::generate(&config);
        let second = CanaryToken::generate(&config);
        assert_ne!(
            first.token, second.token,
            "two generated tokens should differ"
        );
    }

    #[test]
    fn test_generate_with_label() {
        let labelled = CanaryToken::generate_with_label(&default_config(), "my-prompt");
        assert_eq!(labelled.label.as_deref(), Some("my-prompt"));
        assert!(labelled.token.starts_with("CANARY-"));
    }

    #[test]
    fn test_generate_no_label_by_default() {
        let generated = CanaryToken::generate(&default_config());
        assert!(generated.label.is_none());
    }

    #[test]
    fn test_custom_prefix_and_length() {
        let config = CanaryConfig {
            prefix: "TOK_".to_string(),
            token_length: 32,
            ..default_config()
        };
        let generated = CanaryToken::generate(&config);
        assert!(generated.token.starts_with("TOK_"));
        assert_eq!(generated.token.len(), 4 + 32);
    }

    // --- Exact matches ---

    #[test]
    fn test_detect_exact_match() {
        let config = default_config();
        let canary = CanaryToken::generate(&config);
        let response = format!("The system prompt is: {}", canary.token);

        let hits = detect_canary(&response, &[canary], &config);

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].match_type, CanaryMatchType::Full);
        assert!((hits[0].confidence - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_detect_exact_match_position() {
        let config = default_config();
        let canary = CanaryToken::generate(&config);
        let lead_in = "Leaked: ";
        let response = format!("{}{}", lead_in, canary.token);

        let hits = detect_canary(&response, &[canary], &config);

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].position, lead_in.len());
    }

    // --- Case-insensitive matches ---

    #[test]
    fn test_detect_case_insensitive() {
        let config = default_config();
        let canary = CanaryToken::generate(&config);
        let response = canary.token.to_lowercase();

        let hits = detect_canary(&response, &[canary], &config);

        assert_eq!(hits.len(), 1);
        assert!(hits[0].confidence >= 0.95);
    }

    #[test]
    fn test_detect_case_insensitive_upper() {
        let config = default_config();
        let canary = CanaryToken::generate(&config);
        let response = canary.token.to_uppercase();

        let hits = detect_canary(&response, &[canary], &config);

        assert_eq!(hits.len(), 1);
        assert!(hits[0].confidence >= 0.95);
    }

    // --- Partial matches ---

    #[test]
    fn test_detect_partial_match() {
        let config = partial_config(8);
        let canary = CanaryToken::generate(&config);
        let response = format!("Some text with {} inside", &canary.token[..10]);

        let hits = detect_canary(&response, &[canary], &config);

        assert_eq!(hits.len(), 1);
        match &hits[0].match_type {
            CanaryMatchType::Partial { matched_length } => {
                assert!(*matched_length >= 10);
            }
            other => panic!("expected Partial, got {:?}", other),
        }
    }

    #[test]
    fn test_partial_match_respects_min_length() {
        let config = partial_config(20);
        let canary = CanaryToken::generate(&config);
        let response = format!("Some text with {} inside", &canary.token[..8]);

        let hits = detect_canary(&response, &[canary], &config);

        assert!(
            hits.is_empty(),
            "short substring should not trigger partial match"
        );
    }

    #[test]
    fn test_partial_match_disabled() {
        let config = whole_token_only_config();
        let canary = CanaryToken::generate(&config);
        let response = format!("Some text with {} inside", &canary.token[..10]);

        let hits = detect_canary(&response, &[canary], &config);

        assert!(hits.is_empty(), "partial detection should be disabled");
    }

    // --- Encoded matches ---

    #[test]
    fn test_detect_base64_encoded() {
        let config = whole_token_only_config();
        let canary = CanaryToken::generate(&config);
        let response = format!(
            "Here is some data: {}",
            BASE64_STANDARD.encode(canary.token.as_bytes())
        );

        let hits = detect_canary(&response, &[canary], &config);

        assert_eq!(hits.len(), 1);
        assert_eq!(
            hits[0].match_type,
            CanaryMatchType::Encoded {
                encoding: "base64".to_string()
            }
        );
        assert!((hits[0].confidence - 0.9).abs() < f64::EPSILON);
    }

    #[test]
    fn test_detect_hex_encoded() {
        let config = whole_token_only_config();
        let canary = CanaryToken::generate(&config);
        let response = format!("Hex dump: {}", hex_encode(&canary.token));

        let hits = detect_canary(&response, &[canary], &config);

        assert_eq!(hits.len(), 1);
        assert_eq!(
            hits[0].match_type,
            CanaryMatchType::Encoded {
                encoding: "hex".to_string()
            }
        );
        assert!((hits[0].confidence - 0.85).abs() < f64::EPSILON);
    }

    #[test]
    fn test_detect_reversed_token() {
        let config = whole_token_only_config();
        let canary = CanaryToken::generate(&config);
        let backwards: String = canary.token.chars().rev().collect();
        let response = format!("Reversed: {backwards}");

        let hits = detect_canary(&response, &[canary], &config);

        assert_eq!(hits.len(), 1);
        assert_eq!(
            hits[0].match_type,
            CanaryMatchType::Encoded {
                encoding: "reversed".to_string()
            }
        );
    }

    // --- Negative cases ---

    #[test]
    fn test_no_canary_no_detection() {
        let config = default_config();
        let canary = CanaryToken::generate(&config);
        let response = "This is a perfectly normal response with no tokens.";

        let hits = detect_canary(response, &[canary], &config);

        assert!(hits.is_empty(), "should not produce false positives");
    }

    #[test]
    fn test_no_canary_detection_disabled() {
        let config = CanaryConfig {
            enabled: false,
            ..default_config()
        };
        let canary = CanaryToken::generate(&config);
        let response = format!("Leaked: {}", canary.token);

        let hits = detect_canary(&response, &[canary], &config);

        assert!(
            hits.is_empty(),
            "should return empty when disabled even if token present"
        );
    }

    #[test]
    fn test_no_false_positives_on_similar_text() {
        let config = whole_token_only_config();
        let canary = CanaryToken::generate(&config);
        let response = "CANARY-something-else-entirely and more text";

        let hits = detect_canary(response, &[canary], &config);

        assert!(
            hits.is_empty(),
            "different canary prefix text should not match"
        );
    }

    // --- Security findings ---

    #[test]
    fn test_security_finding_full_match() {
        let config = default_config();
        let canary = CanaryToken::generate(&config);
        let response = format!("Leaked: {}", canary.token);

        let findings = detect_canary_leakage(&response, &[canary], &config);

        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].finding_type, "canary_token_leakage");
        assert_eq!(findings[0].severity, SecuritySeverity::Critical);
        assert!(findings[0].description.contains("exact match"));
    }

    #[test]
    fn test_security_finding_encoded_match() {
        let config = whole_token_only_config();
        let canary = CanaryToken::generate(&config);
        let response = format!("Hex: {}", hex_encode(&canary.token));

        let findings = detect_canary_leakage(&response, &[canary], &config);

        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].severity, SecuritySeverity::High);
        assert!(findings[0].description.contains("encoded match"));
    }

    #[test]
    fn test_security_finding_partial_match() {
        let config = partial_config(8);
        let canary = CanaryToken::generate(&config);
        let response = format!("Fragment: {}", &canary.token[..10]);

        let findings = detect_canary_leakage(&response, &[canary], &config);

        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].severity, SecuritySeverity::Medium);
        assert!(findings[0].description.contains("partial match"));
    }

    #[test]
    fn test_security_finding_metadata() {
        let config = default_config();
        let canary = CanaryToken::generate(&config);
        let expected_token = canary.token.clone();
        let response = format!("Leak: {expected_token}");

        let findings = detect_canary_leakage(&response, &[canary], &config);

        assert_eq!(findings.len(), 1);
        let finding = &findings[0];
        assert_eq!(
            finding.metadata.get("token").map(String::as_str),
            Some(expected_token.as_str())
        );
        assert!(finding.metadata.contains_key("match_type"));
        assert!(finding.metadata.contains_key("position"));
        assert_eq!(finding.location.as_deref(), Some("response.content"));
    }

    // --- Token store ---

    #[test]
    fn test_store_add_and_get() {
        let store = CanaryTokenStore::new();
        let canary = CanaryToken::generate(&default_config());
        let expected_token = canary.token.clone();

        store.add("tenant-1", canary);

        let stored = store.get("tenant-1");
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].token, expected_token);
    }

    #[test]
    fn test_store_get_empty_tenant() {
        let store = CanaryTokenStore::new();
        assert!(store.get("nonexistent").is_empty());
    }

    #[test]
    fn test_store_remove() {
        let store = CanaryTokenStore::new();
        let canary = CanaryToken::generate(&default_config());
        let expected_token = canary.token.clone();

        store.add("tenant-1", canary);

        assert!(store.remove("tenant-1", &expected_token));
        assert!(store.get("tenant-1").is_empty());
    }

    #[test]
    fn test_store_remove_nonexistent() {
        let store = CanaryTokenStore::new();
        assert!(!store.remove("tenant-1", "no-such-token"));
    }

    #[test]
    fn test_store_multiple_tenants() {
        let store = CanaryTokenStore::new();
        let config = default_config();

        for tenant in ["tenant-a", "tenant-a", "tenant-b"] {
            store.add(tenant, CanaryToken::generate(&config));
        }

        assert_eq!(store.get("tenant-a").len(), 2);
        assert_eq!(store.get("tenant-b").len(), 1);
        assert_eq!(store.tenant_count(), 2);
        assert_eq!(store.token_count(), 3);
    }

    #[test]
    fn test_store_remove_cleans_up_empty_tenant() {
        let store = CanaryTokenStore::new();
        let canary = CanaryToken::generate(&default_config());
        let expected_token = canary.token.clone();

        store.add("tenant-1", canary);
        store.remove("tenant-1", &expected_token);

        assert_eq!(store.tenant_count(), 0);
    }

    #[test]
    fn test_store_thread_safety() {
        use std::thread;

        let store = CanaryTokenStore::new();
        let config = default_config();

        // Each worker writes to its own tenant through a cloned handle and
        // reads the result back.
        let mut workers = Vec::new();
        for i in 0..10 {
            let store = store.clone();
            let config = config.clone();
            workers.push(thread::spawn(move || {
                let tenant = format!("tenant-{i}");
                store.add(&tenant, CanaryToken::generate(&config));
                store.get(&tenant)
            }));
        }

        for worker in workers {
            let stored = worker.join().expect("thread panicked");
            assert!(!stored.is_empty());
        }

        assert_eq!(store.tenant_count(), 10);
    }

    // --- Injection ---

    #[test]
    fn test_inject_canary_format() {
        let canary = CanaryToken::generate(&default_config());
        let prompt = "You are a helpful assistant.";

        let injected = inject_canary(prompt, &canary);

        assert!(injected.starts_with(prompt));
        assert!(injected.contains(&format!("[SYSTEM_INTEGRITY_TOKEN: {}]", canary.token)));
    }

    #[test]
    fn test_inject_canary_preserves_original() {
        let canary = CanaryToken::generate(&default_config());
        let prompt = "Original prompt text\nWith multiple lines.";

        let injected = inject_canary(prompt, &canary);

        assert!(injected.starts_with(prompt));
    }

    #[test]
    fn test_inject_and_detect_roundtrip() {
        let config = default_config();
        let canary = CanaryToken::generate(&config);
        let injected = inject_canary("System prompt", &canary);
        let response = format!("My system prompt is: {injected}");

        let hits = detect_canary(&response, &[canary], &config);

        assert!(!hits.is_empty(), "should detect the canary in leaked prompt");
        assert_eq!(hits[0].match_type, CanaryMatchType::Full);
    }

    // --- Multiple tokens ---

    #[test]
    fn test_detect_multiple_tokens() {
        let config = default_config();
        let first = CanaryToken::generate(&config);
        let second = CanaryToken::generate(&config);
        let response = format!("First: {} Second: {}", first.token, second.token);

        let hits = detect_canary(&response, &[first, second], &config);

        assert_eq!(hits.len(), 2);
    }
}