use std::collections::{HashMap, VecDeque};
use std::time::{SystemTime, UNIX_EPOCH};
use dashmap::DashMap;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use crate::profiler::patterns::detect_pattern;
use crate::profiler::schema_types::{
EndpointSchema, FieldSchema, FieldType, SchemaViolation, ValidationResult,
};
/// Tuning parameters for [`SchemaLearner`].
///
/// Call [`SchemaLearnerConfig::validate`] on configurations that come from
/// outside the program; the learner itself does not validate them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaLearnerConfig {
/// Maximum number of endpoint schemas kept in memory; when a new template
/// arrives at this limit, the least recently used schema is evicted.
pub max_schemas: usize,
/// Validation reports no violations for a template until it has at least
/// this many learned samples.
pub min_samples_for_validation: u32,
/// Maximum object nesting depth followed when learning and validating;
/// deeper levels are ignored.
pub max_nesting_depth: usize,
/// Cap on the number of flattened field entries stored per request or
/// response schema.
pub max_fields_per_schema: usize,
/// Multiplier applied to the learned maximum string length before a value is
/// reported as too long; `validate` requires it to be >= 1.0.
pub string_length_tolerance: f64,
/// Factor by which the learned numeric range is widened before a value is
/// reported as too small or too large; `validate` requires it to be >= 1.0.
pub number_value_tolerance: f64,
/// Fraction (0.0..=1.0) of samples in which a field must have been seen for
/// its absence to be reported as a missing field.
pub required_field_threshold: f64,
}
/// Error describing the first out-of-range setting found by
/// [`SchemaLearnerConfig::validate`].
#[derive(Debug, Clone, PartialEq)]
pub struct ConfigValidationError {
    /// Name of the offending configuration field.
    pub field: &'static str,
    /// Human-readable explanation, including the rejected value.
    pub message: String,
}

impl std::fmt::Display for ConfigValidationError {
    /// Renders as `Invalid <field>: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { field, message } = self;
        write!(f, "Invalid {}: {}", field, message)
    }
}

/// Marker impl so the error composes with `?` and `Box<dyn Error>`.
impl std::error::Error for ConfigValidationError {}
impl SchemaLearnerConfig {
    /// Checks that the tuning parameters are usable.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigValidationError`] naming the first offending field when:
    /// - `string_length_tolerance` or `number_value_tolerance` is below `1.0` or is
    ///   NaN (a tolerance below `1.0` shrinks learned bounds, so data the learner
    ///   itself observed would be flagged);
    /// - `required_field_threshold` lies outside `0.0..=1.0` or is NaN.
    pub fn validate(&self) -> Result<(), ConfigValidationError> {
        Self::check_tolerance("string_length_tolerance", self.string_length_tolerance)?;
        Self::check_tolerance("number_value_tolerance", self.number_value_tolerance)?;
        // `RangeInclusive::contains` is false for NaN, so NaN is rejected here too.
        if !(0.0..=1.0).contains(&self.required_field_threshold) {
            return Err(ConfigValidationError {
                field: "required_field_threshold",
                message: format!(
                    "must be between 0.0 and 1.0 (got {})",
                    self.required_field_threshold
                ),
            });
        }
        Ok(())
    }

    /// Accepts a tolerance only if it is `>= 1.0`.
    ///
    /// The check is phrased positively on purpose: NaN compares false against
    /// everything, so a plain `value < 1.0` rejection test would let NaN through.
    fn check_tolerance(field: &'static str, value: f64) -> Result<(), ConfigValidationError> {
        if value >= 1.0 {
            return Ok(());
        }
        Err(ConfigValidationError {
            field,
            message: format!(
                "must be >= 1.0 to avoid rejecting baseline data (got {})",
                value
            ),
        })
    }
}
impl Default for SchemaLearnerConfig {
fn default() -> Self {
Self {
max_schemas: 5000,
min_samples_for_validation: 10,
max_nesting_depth: 10,
max_fields_per_schema: 100,
string_length_tolerance: 1.5,
number_value_tolerance: 1.5,
required_field_threshold: 0.9,
}
}
}
/// One access record in an [`LruTracker`] queue.
///
/// A key may appear several times in the queue. Only the entry whose
/// `generation` equals the tracker's current generation for that key is live;
/// the others are stale and get skipped.
#[derive(Debug, Clone)]
struct LruEntry {
    key: String,
    generation: u64,
}

/// Least-recently-used ordering over string keys with amortised O(1) `touch`.
///
/// Rather than moving a key inside the queue on every access, `touch` appends
/// a new entry tagged with a fresh generation and records that generation in
/// `generations`. Earlier entries for the same key thereby become stale.
/// Stale entries are skipped by `evict_oldest` and compacted away
/// periodically, so the queue stays proportional to the number of live keys
/// even when eviction never runs.
struct LruTracker {
    /// Access records, oldest at the front; may contain stale entries.
    queue: VecDeque<LruEntry>,
    /// Current (live) generation for every tracked key.
    generations: HashMap<String, u64>,
    /// Generation assigned by the next `touch`.
    next_generation: u64,
}

impl LruTracker {
    /// Stale entries tolerated beyond twice the live-key count before the
    /// queue is compacted; keeps compaction rare for very small key sets.
    const COMPACTION_SLACK: usize = 32;

    /// Creates an empty tracker with room for `capacity` keys preallocated.
    fn new(capacity: usize) -> Self {
        Self {
            queue: VecDeque::with_capacity(capacity),
            generations: HashMap::with_capacity(capacity),
            next_generation: 0,
        }
    }

    /// Marks `key` as most recently used.
    ///
    /// Returns `true` if `key` was not tracked before this call.
    fn touch(&mut self, key: &str) -> bool {
        let generation = self.next_generation;
        self.next_generation = self.next_generation.wrapping_add(1);
        let is_new = self
            .generations
            .insert(key.to_string(), generation)
            .is_none();
        self.queue.push_back(LruEntry {
            key: key.to_string(),
            generation,
        });
        self.compact_if_needed();
        is_new
    }

    /// Drops stale queue entries once they clearly outnumber live keys.
    ///
    /// Without this, a hot key that is touched on every request but never
    /// evicted would grow the queue by one allocation per touch, forever.
    /// After compaction exactly one entry per live key remains, so the next
    /// compaction is at least `len() + COMPACTION_SLACK` touches away (O(1)
    /// amortised). `retain` preserves queue order, so recency is unaffected.
    fn compact_if_needed(&mut self) {
        let limit = self
            .len()
            .saturating_mul(2)
            .saturating_add(Self::COMPACTION_SLACK);
        if self.queue.len() <= limit {
            return;
        }
        let generations = &self.generations;
        self.queue
            .retain(|entry| generations.get(&entry.key) == Some(&entry.generation));
    }

    /// Stops tracking `key`; its queue entries become stale and are skipped.
    // Not called by `SchemaLearner` at present; kept for targeted removals.
    #[allow(dead_code)]
    fn remove(&mut self, key: &str) {
        self.generations.remove(key);
    }

    /// Removes and returns the least recently touched live key, or `None` if
    /// no live key remains. Stale entries encountered on the way are discarded.
    fn evict_oldest(&mut self) -> Option<String> {
        while let Some(entry) = self.queue.pop_front() {
            if self.generations.get(&entry.key) == Some(&entry.generation) {
                self.generations.remove(&entry.key);
                return Some(entry.key);
            }
        }
        None
    }

    /// Number of live (tracked) keys.
    fn len(&self) -> usize {
        self.generations.len()
    }

    /// Forgets every key and resets the generation counter.
    fn clear(&mut self) {
        self.queue.clear();
        self.generations.clear();
        self.next_generation = 0;
    }
}
/// Learns per-endpoint JSON field schemas from observed request/response
/// bodies and validates new bodies against them.
///
/// Schemas are keyed by endpoint template. The number retained is bounded by
/// `SchemaLearnerConfig::max_schemas`, with least-recently-used eviction.
pub struct SchemaLearner {
/// Learned schemas keyed by endpoint template.
schemas: DashMap<String, EndpointSchema>,
/// Recency tracker used to choose which template to evict.
lru: Mutex<LruTracker>,
/// Tuning parameters supplied at construction.
config: SchemaLearnerConfig,
}
impl Default for SchemaLearner {
fn default() -> Self {
Self::new()
}
}
impl SchemaLearner {
    /// Creates a learner using [`SchemaLearnerConfig::default`].
    pub fn new() -> Self {
        Self::with_config(SchemaLearnerConfig::default())
    }

    /// Creates a learner with an explicit configuration.
    ///
    /// The configuration is used as-is; call [`SchemaLearnerConfig::validate`]
    /// beforehand if it comes from an external source.
    pub fn with_config(config: SchemaLearnerConfig) -> Self {
        Self {
            schemas: DashMap::with_capacity(config.max_schemas),
            lru: Mutex::new(LruTracker::new(config.max_schemas)),
            config,
        }
    }

    /// Returns the configuration this learner was built with.
    pub fn config(&self) -> &SchemaLearnerConfig {
        &self.config
    }

    /// Number of endpoint schemas currently held.
    pub fn len(&self) -> usize {
        self.schemas.len()
    }

    /// Returns `true` when no endpoint schema is held.
    pub fn is_empty(&self) -> bool {
        self.schemas.is_empty()
    }

    /// Milliseconds since the Unix epoch, or `0` if the system clock reads earlier.
    fn now_ms() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_millis() as u64)
            .unwrap_or(0)
    }

    /// Learns field structure from a request body for `template`.
    ///
    /// Bodies that are not JSON objects are ignored (no schema is created).
    /// Each accepted request counts as one sample.
    pub fn learn_from_request(&self, template: &str, request_body: &serde_json::Value) {
        self.learn_internal(template, request_body, SchemaTarget::Request);
    }

    /// Learns field structure from a response body for `template`.
    ///
    /// Bodies that are not JSON objects are ignored. Responses enrich the
    /// schema but do not increment its sample count.
    pub fn learn_from_response(&self, template: &str, response_body: &serde_json::Value) {
        self.learn_internal(template, response_body, SchemaTarget::Response);
    }

    /// Learns from an optional request/response pair and counts exactly one sample.
    ///
    /// Unlike the single-body methods, the schema is created and the sample is
    /// counted even when both bodies are absent or are not JSON objects.
    pub fn learn_from_pair(
        &self,
        template: &str,
        request_body: Option<&serde_json::Value>,
        response_body: Option<&serde_json::Value>,
    ) {
        let now = Self::now_ms();
        self.ensure_schema(template, now);
        // `update_schema_fields` ignores values that are not JSON objects.
        if let Some(req) = request_body {
            self.update_schema_fields(template, req, SchemaTarget::Request, "", 0);
        }
        if let Some(resp) = response_body {
            self.update_schema_fields(template, resp, SchemaTarget::Response, "", 0);
        }
        self.record_sample(template, now);
    }

    /// Shared body of `learn_from_request` / `learn_from_response`.
    fn learn_internal(&self, template: &str, body: &serde_json::Value, target: SchemaTarget) {
        if !body.is_object() {
            return;
        }
        let now = Self::now_ms();
        self.ensure_schema(template, now);
        self.update_schema_fields(template, body, target, "", 0);
        // Only requests count as samples; the sample count gates validation.
        if matches!(target, SchemaTarget::Request) {
            self.record_sample(template, now);
        }
    }

    /// Counts one learned sample for `template` and stamps its update time.
    ///
    /// The count saturates instead of overflowing on very long-lived endpoints.
    fn record_sample(&self, template: &str, now: u64) {
        if let Some(mut schema) = self.schemas.get_mut(template) {
            schema.sample_count = schema.sample_count.saturating_add(1);
            schema.last_updated_ms = now;
        }
    }

    /// Makes sure a schema exists for `template` and marks it most recently used.
    ///
    /// Every mutation of the LRU tracker and every insertion/removal in
    /// `schemas` (here, in `import` and in `clear`) happens while holding the
    /// `lru` lock. This keeps the tracked keys and the map's keys in sync. The
    /// existence check is therefore done under the lock as well; checking
    /// before locking could observe a schema that is evicted a moment later,
    /// leaving a tracked key with no schema behind it.
    ///
    /// At capacity, least recently used schemas are evicted until there is
    /// room. If `max_schemas` is 0, the new schema is still inserted.
    fn ensure_schema(&self, template: &str, now: u64) {
        let mut lru = self.lru.lock();
        if self.schemas.contains_key(template) {
            lru.touch(template);
            return;
        }
        while self.schemas.len() >= self.config.max_schemas {
            match lru.evict_oldest() {
                Some(victim) => {
                    self.schemas.remove(&victim);
                }
                None => break,
            }
        }
        lru.touch(template);
        self.schemas.insert(
            template.to_string(),
            EndpointSchema::new(template.to_string(), now),
        );
    }

    /// Converts a byte length to `u32`, saturating at `u32::MAX` instead of
    /// wrapping as a plain `as` cast would.
    fn clamp_len(len: usize) -> u32 {
        len.min(u32::MAX as usize) as u32
    }

    /// Records the fields of the JSON object `value` into the request or
    /// response field map of `template`'s schema.
    ///
    /// Nested objects are flattened into dotted names (`"user.address.city"`)
    /// and processed recursively up to `max_nesting_depth`. Non-object values
    /// and unknown templates are ignored.
    fn update_schema_fields(
        &self,
        template: &str,
        value: &serde_json::Value,
        target: SchemaTarget,
        prefix: &str,
        depth: usize,
    ) {
        if depth >= self.config.max_nesting_depth {
            return;
        }
        let obj = match value.as_object() {
            Some(o) => o,
            None => return,
        };
        // Nested objects are recursed into only after the map guard below is
        // dropped, because each recursive call acquires its own guard.
        let mut nested_objects: Vec<(String, &serde_json::Value)> = Vec::new();
        {
            let mut schema_guard = match self.schemas.get_mut(template) {
                Some(s) => s,
                None => return,
            };
            let schema_map = match target {
                SchemaTarget::Request => &mut schema_guard.request_schema,
                SchemaTarget::Response => &mut schema_guard.response_schema,
            };
            for (key, val) in obj {
                let field_name = if prefix.is_empty() {
                    key.clone()
                } else {
                    format!("{}.{}", prefix, key)
                };
                // The cap only blocks *new* fields. Known fields must keep
                // accumulating observations (including `seen_count`); otherwise
                // they fall below the required-field threshold as samples grow.
                let at_capacity = schema_map.len() >= self.config.max_fields_per_schema;
                if at_capacity && !schema_map.contains_key(&field_name) {
                    continue;
                }
                let field_type = FieldType::from_json_value(val);
                let field_schema = schema_map
                    .entry(field_name.clone())
                    .or_insert_with(|| FieldSchema::new(field_name.clone()));
                field_schema.record_type(field_type);
                match val {
                    serde_json::Value::String(s) => {
                        let pattern = detect_pattern(s);
                        field_schema.update_string_constraints(Self::clamp_len(s.len()), pattern);
                    }
                    serde_json::Value::Number(n) => {
                        if let Some(f) = n.as_f64() {
                            field_schema.update_number_constraints(f);
                        }
                    }
                    serde_json::Value::Array(arr) => {
                        for item in arr {
                            field_schema.add_array_item_type(FieldType::from_json_value(item));
                        }
                    }
                    serde_json::Value::Object(_) => {
                        if field_schema.object_schema.is_none() {
                            field_schema.object_schema = Some(HashMap::new());
                        }
                        nested_objects.push((field_name, val));
                    }
                    _ => {}
                }
            }
        }
        for (field_name, val) in nested_objects {
            self.update_schema_fields(template, val, target, &field_name, depth + 1);
        }
    }

    /// Validates a request body against the learned request schema.
    ///
    /// Returns a result with no violations when `template` is unknown or has
    /// fewer than `min_samples_for_validation` samples.
    pub fn validate_request(
        &self,
        template: &str,
        request_body: &serde_json::Value,
    ) -> ValidationResult {
        self.validate_internal(template, request_body, SchemaTarget::Request)
    }

    /// Validates a response body against the learned response schema.
    ///
    /// Returns a result with no violations when `template` is unknown or has
    /// fewer than `min_samples_for_validation` samples.
    pub fn validate_response(
        &self,
        template: &str,
        response_body: &serde_json::Value,
    ) -> ValidationResult {
        self.validate_internal(template, response_body, SchemaTarget::Response)
    }

    /// Shared body of `validate_request` / `validate_response`.
    fn validate_internal(
        &self,
        template: &str,
        body: &serde_json::Value,
        target: SchemaTarget,
    ) -> ValidationResult {
        let mut result = ValidationResult::new();
        let schema = match self.schemas.get(template) {
            Some(s) => s,
            // Unknown endpoint: nothing to compare against.
            None => return result,
        };
        // Too few samples for the learned constraints to be trusted yet.
        if schema.sample_count < self.config.min_samples_for_validation {
            return result;
        }
        let schema_map = match target {
            SchemaTarget::Request => &schema.request_schema,
            SchemaTarget::Response => &schema.response_schema,
        };
        self.validate_against_schema(
            schema_map,
            body,
            "",
            &mut result,
            schema.sample_count,
            0,
        );
        result
    }

    /// Returns the key of `field_name` if it names an immediate child of the
    /// object at `prefix` (`""` meaning the root); otherwise `None`.
    ///
    /// This uses `strip_prefix` rather than `starts_with` plus byte slicing.
    /// As a result, a field that merely shares a textual prefix (`"user"` vs
    /// `"username.x"`) is not mistaken for a child, and a non-ASCII key can
    /// never cause a slice at a non-char boundary (which would panic).
    /// NOTE(review): JSON keys that themselves contain `.` collide with the
    /// flattened naming scheme and cannot be told apart here.
    fn direct_child_key<'a>(field_name: &'a str, prefix: &str) -> Option<&'a str> {
        let rest = if prefix.is_empty() {
            field_name
        } else {
            field_name.strip_prefix(prefix)?.strip_prefix('.')?
        };
        if rest.contains('.') {
            None
        } else {
            Some(rest)
        }
    }

    /// Checks the JSON object `data` (located at `prefix`) against the
    /// flattened field map, recursing into nested objects up to
    /// `max_nesting_depth`.
    ///
    /// Reports, at this level:
    /// - unexpected fields;
    /// - type mismatches against the dominant learned type (null is allowed
    ///   for nullable fields);
    /// - string and number constraint violations;
    /// - missing fields whose `seen_count` reaches
    ///   `sample_count * required_field_threshold`.
    fn validate_against_schema(
        &self,
        root_schema_map: &HashMap<String, FieldSchema>,
        data: &serde_json::Value,
        prefix: &str,
        result: &mut ValidationResult,
        sample_count: u32,
        depth: usize,
    ) {
        if depth >= self.config.max_nesting_depth {
            return;
        }
        let obj = match data.as_object() {
            Some(o) => o,
            None => return,
        };
        for (key, val) in obj {
            let field_name = if prefix.is_empty() {
                key.clone()
            } else {
                format!("{}.{}", prefix, key)
            };
            let field_schema = match root_schema_map.get(&field_name) {
                Some(s) => s,
                None => {
                    result.add(SchemaViolation::unexpected_field(&field_name));
                    continue;
                }
            };
            let actual_type = FieldType::from_json_value(val);
            let dominant_type = field_schema.dominant_type();
            if actual_type != dominant_type && !(val.is_null() && field_schema.nullable) {
                result.add(SchemaViolation::type_mismatch(
                    &field_name,
                    dominant_type,
                    actual_type,
                ));
            }
            match val {
                serde_json::Value::String(s) => {
                    self.validate_string_field(&field_name, s, field_schema, result);
                }
                serde_json::Value::Number(n) => {
                    if let Some(f) = n.as_f64() {
                        self.validate_number_field(&field_name, f, field_schema, result);
                    }
                }
                serde_json::Value::Object(_) => {
                    self.validate_against_schema(
                        root_schema_map,
                        val,
                        &field_name,
                        result,
                        sample_count,
                        depth + 1,
                    );
                }
                _ => {}
            }
        }
        let threshold = (sample_count as f64 * self.config.required_field_threshold) as u32;
        for (field_name, field_schema) in root_schema_map {
            if let Some(child_key) = Self::direct_child_key(field_name, prefix) {
                if field_schema.seen_count >= threshold && !obj.contains_key(child_key) {
                    result.add(SchemaViolation::missing_field(field_name));
                }
            }
        }
    }

    /// Checks a string against the learned length range and pattern.
    ///
    /// The maximum is widened by `string_length_tolerance`; the minimum is
    /// applied as learned. Lengths are byte lengths, matching how they were
    /// learned.
    fn validate_string_field(
        &self,
        field_name: &str,
        value: &str,
        schema: &FieldSchema,
        result: &mut ValidationResult,
    ) {
        let len = Self::clamp_len(value.len());
        if let Some(min) = schema.min_length {
            if len < min {
                result.add(SchemaViolation::string_too_short(field_name, min, len));
            }
        }
        if let Some(max) = schema.max_length {
            let allowed_max = (max as f64 * self.config.string_length_tolerance) as u32;
            if len > allowed_max {
                result.add(SchemaViolation::string_too_long(
                    field_name,
                    allowed_max,
                    len,
                ));
            }
        }
        if let Some(expected_pattern) = schema.pattern {
            let actual_pattern = detect_pattern(value);
            if actual_pattern != Some(expected_pattern) {
                result.add(SchemaViolation::pattern_mismatch(
                    field_name,
                    expected_pattern,
                    actual_pattern,
                ));
            }
        }
    }

    /// Checks a number against the learned range, widened by
    /// `number_value_tolerance`.
    ///
    /// Widening always moves the lower bound down and the upper bound up,
    /// whatever their sign. Scaling a negative bound the same way as a
    /// positive one would instead narrow the range and flag the very values
    /// it was learned from.
    fn validate_number_field(
        &self,
        field_name: &str,
        value: f64,
        schema: &FieldSchema,
        result: &mut ValidationResult,
    ) {
        let tolerance = self.config.number_value_tolerance;
        if let Some(min) = schema.min_value {
            let allowed_min = if min >= 0.0 {
                min * (1.0 / tolerance)
            } else {
                min * tolerance
            };
            if value < allowed_min {
                result.add(SchemaViolation::number_too_small(
                    field_name,
                    allowed_min,
                    value,
                ));
            }
        }
        if let Some(max) = schema.max_value {
            let allowed_max = if max >= 0.0 {
                max * tolerance
            } else {
                max * (1.0 / tolerance)
            };
            if value > allowed_max {
                result.add(SchemaViolation::number_too_large(
                    field_name,
                    allowed_max,
                    value,
                ));
            }
        }
    }

    /// Returns a clone of the learned schema for `template`, if any.
    pub fn get_schema(&self, template: &str) -> Option<EndpointSchema> {
        self.schemas.get(template).map(|s| s.value().clone())
    }

    /// Returns clones of every learned schema (no particular order is guaranteed).
    pub fn get_all_schemas(&self) -> Vec<EndpointSchema> {
        self.schemas
            .iter()
            .map(|entry| entry.value().clone())
            .collect()
    }

    /// Computes aggregate counters in a single pass over the map.
    ///
    /// Unlike collecting every entry guard into a temporary `Vec` first, this
    /// releases each entry as soon as it has been counted. `total_samples`
    /// saturates at `u32::MAX` instead of overflowing.
    pub fn get_stats(&self) -> SchemaLearnerStats {
        let mut total_schemas = 0usize;
        let mut total_samples = 0u32;
        let mut total_fields = 0usize;
        for entry in self.schemas.iter() {
            total_schemas += 1;
            total_samples = total_samples.saturating_add(entry.sample_count);
            total_fields += entry.request_schema.len() + entry.response_schema.len();
        }
        SchemaLearnerStats {
            total_schemas,
            total_samples,
            avg_fields_per_endpoint: if total_schemas == 0 {
                0.0
            } else {
                total_fields as f64 / total_schemas as f64
            },
        }
    }

    /// Alias for [`Self::get_all_schemas`], paired with [`Self::import`].
    pub fn export(&self) -> Vec<EndpointSchema> {
        self.get_all_schemas()
    }

    /// Replaces all learned schemas with `schemas`.
    ///
    /// Schemas are ordered by `last_updated_ms` to rebuild LRU recency. If more
    /// than `max_schemas` are supplied, only the most recently updated ones are
    /// kept, so the capacity bound holds from the start. The map and the LRU
    /// tracker are rebuilt under one lock so concurrent learners never see them
    /// out of sync.
    pub fn import(&self, schemas: Vec<EndpointSchema>) {
        let mut sorted_schemas = schemas;
        sorted_schemas.sort_by_key(|s| s.last_updated_ms);
        let overflow = sorted_schemas
            .len()
            .saturating_sub(self.config.max_schemas);
        let mut lru = self.lru.lock();
        lru.clear();
        self.schemas.clear();
        for schema in sorted_schemas.into_iter().skip(overflow) {
            lru.touch(&schema.template);
            self.schemas.insert(schema.template.clone(), schema);
        }
    }

    /// Forgets every learned schema.
    ///
    /// Both structures are cleared under the LRU lock, so a concurrent
    /// `ensure_schema` cannot leave an entry in one but not the other.
    pub fn clear(&self) {
        let mut lru = self.lru.lock();
        lru.clear();
        self.schemas.clear();
    }
}
/// Selects which half of an `EndpointSchema` an operation reads or writes.
#[derive(Debug, Clone, Copy)]
enum SchemaTarget {
/// The schema's `request_schema` field map.
Request,
/// The schema's `response_schema` field map.
Response,
}
/// Aggregate counters returned by `SchemaLearner::get_stats`.
#[derive(Debug, Clone, Serialize)]
pub struct SchemaLearnerStats {
/// Number of endpoint schemas currently held.
pub total_schemas: usize,
/// Sum of `sample_count` across all held schemas.
pub total_samples: u32,
/// Mean number of request + response fields per schema; `0.0` when there are none.
pub avg_fields_per_endpoint: f64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::profiler::schema_types::{PatternType, ViolationType};
    use serde_json::json;

    #[test]
    fn test_learn_from_request() {
        let learner = SchemaLearner::new();
        let body = json!({
            "username": "john_doe",
            "email": "john@example.com",
            "age": 30
        });
        learner.learn_from_request("/api/users", &body);
        let schema = learner.get_schema("/api/users").unwrap();
        assert_eq!(schema.sample_count, 1);
        assert!(schema.request_schema.contains_key("username"));
        assert!(schema.request_schema.contains_key("email"));
        assert!(schema.request_schema.contains_key("age"));
    }

    #[test]
    fn test_learn_type_tracking() {
        let learner = SchemaLearner::new();
        for i in 0..10 {
            let body = json!({
                "id": i,
                "name": format!("user_{}", i)
            });
            learner.learn_from_request("/api/users", &body);
        }
        let schema = learner.get_schema("/api/users").unwrap();
        let id_schema = schema.request_schema.get("id").unwrap();
        let name_schema = schema.request_schema.get("name").unwrap();
        assert_eq!(id_schema.dominant_type(), FieldType::Number);
        assert_eq!(name_schema.dominant_type(), FieldType::String);
        assert_eq!(id_schema.seen_count, 10);
    }

    #[test]
    fn test_learn_string_constraints() {
        let learner = SchemaLearner::new();
        let bodies = [
            json!({"name": "ab"}),
            json!({"name": "abcdef"}),
            json!({"name": "abcd"}),
        ];
        for body in &bodies {
            learner.learn_from_request("/api/test", body);
        }
        let schema = learner.get_schema("/api/test").unwrap();
        let name_schema = schema.request_schema.get("name").unwrap();
        assert_eq!(name_schema.min_length, Some(2));
        assert_eq!(name_schema.max_length, Some(6));
    }

    #[test]
    fn test_learn_pattern_detection() {
        let learner = SchemaLearner::new();
        let body = json!({
            "id": "550e8400-e29b-41d4-a716-446655440000",
            "email": "user@example.com"
        });
        learner.learn_from_request("/api/users", &body);
        let schema = learner.get_schema("/api/users").unwrap();
        let id_schema = schema.request_schema.get("id").unwrap();
        let email_schema = schema.request_schema.get("email").unwrap();
        assert_eq!(id_schema.pattern, Some(PatternType::Uuid));
        assert_eq!(email_schema.pattern, Some(PatternType::Email));
    }

    #[test]
    fn test_learn_nested_objects() {
        let learner = SchemaLearner::new();
        let body = json!({
            "user": {
                "name": "John",
                "address": {
                    "city": "NYC"
                }
            }
        });
        learner.learn_from_request("/api/data", &body);
        let schema = learner.get_schema("/api/data").unwrap();
        assert!(schema.request_schema.contains_key("user"));
        assert!(schema.request_schema.contains_key("user.name"));
        assert!(schema.request_schema.contains_key("user.address"));
        assert!(schema.request_schema.contains_key("user.address.city"));
    }

    #[test]
    fn test_validate_unexpected_field() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            min_samples_for_validation: 5,
            ..Default::default()
        });
        for _ in 0..10 {
            learner.learn_from_request("/api/users", &json!({"name": "test"}));
        }
        let result =
            learner.validate_request("/api/users", &json!({"name": "test", "malicious": "value"}));
        assert!(!result.is_valid());
        assert!(result
            .violations
            .iter()
            .any(|v| v.violation_type == ViolationType::UnexpectedField));
    }

    #[test]
    fn test_validate_type_mismatch() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            min_samples_for_validation: 5,
            ..Default::default()
        });
        for i in 0..10 {
            learner.learn_from_request("/api/users", &json!({"id": i}));
        }
        let result = learner.validate_request("/api/users", &json!({"id": "not_a_number"}));
        assert!(!result.is_valid());
        assert!(result
            .violations
            .iter()
            .any(|v| v.violation_type == ViolationType::TypeMismatch));
    }

    #[test]
    fn test_validate_string_too_long() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            min_samples_for_validation: 5,
            string_length_tolerance: 2.0,
            ..Default::default()
        });
        for _ in 0..10 {
            learner.learn_from_request("/api/users", &json!({"name": "john"}));
        }
        // Learned max is 4 bytes; with 2.0 tolerance anything over 8 is too long.
        let long_name = "a".repeat(20);
        let result = learner.validate_request("/api/users", &json!({"name": long_name}));
        assert!(!result.is_valid());
        assert!(result
            .violations
            .iter()
            .any(|v| v.violation_type == ViolationType::StringTooLong));
    }

    #[test]
    fn test_validate_pattern_mismatch() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            min_samples_for_validation: 5,
            ..Default::default()
        });
        for _ in 0..10 {
            learner.learn_from_request(
                "/api/users",
                &json!({"id": "550e8400-e29b-41d4-a716-446655440000"}),
            );
        }
        let result = learner.validate_request("/api/users", &json!({"id": "not-a-uuid-value"}));
        assert!(!result.is_valid());
        assert!(result
            .violations
            .iter()
            .any(|v| v.violation_type == ViolationType::PatternMismatch));
    }

    #[test]
    fn test_validate_insufficient_samples() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            min_samples_for_validation: 10,
            ..Default::default()
        });
        for _ in 0..5 {
            learner.learn_from_request("/api/users", &json!({"name": "test"}));
        }
        let result = learner.validate_request("/api/users", &json!({"malicious": "field"}));
        assert!(result.is_valid());
    }

    #[test]
    fn test_lru_eviction() {
        // Recency is tracked with a generation counter rather than wall-clock
        // time, so no sleeps are needed between inserts.
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            max_schemas: 3,
            ..Default::default()
        });
        learner.learn_from_request("/api/users", &json!({"a": 1}));
        learner.learn_from_request("/api/orders", &json!({"b": 2}));
        learner.learn_from_request("/api/products", &json!({"c": 3}));
        learner.learn_from_request("/api/inventory", &json!({"d": 4}));
        assert_eq!(learner.len(), 3);
        assert!(learner.get_schema("/api/users").is_none());
        assert!(learner.get_schema("/api/orders").is_some());
        assert!(learner.get_schema("/api/products").is_some());
        assert!(learner.get_schema("/api/inventory").is_some());
    }

    #[test]
    fn test_lru_touch_refreshes_recency() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            max_schemas: 2,
            ..Default::default()
        });
        learner.learn_from_request("/a", &json!({"x": 1}));
        learner.learn_from_request("/b", &json!({"x": 1}));
        // Re-learning "/a" makes "/b" the least recently used template.
        learner.learn_from_request("/a", &json!({"x": 2}));
        learner.learn_from_request("/c", &json!({"x": 1}));
        assert_eq!(learner.len(), 2);
        assert!(learner.get_schema("/a").is_some());
        assert!(learner.get_schema("/b").is_none());
        assert!(learner.get_schema("/c").is_some());
    }

    #[test]
    fn test_config_validation() {
        assert!(SchemaLearnerConfig::default().validate().is_ok());
        let err = SchemaLearnerConfig {
            string_length_tolerance: 0.5,
            ..Default::default()
        }
        .validate()
        .unwrap_err();
        assert_eq!(err.field, "string_length_tolerance");
        let err = SchemaLearnerConfig {
            required_field_threshold: 1.5,
            ..Default::default()
        }
        .validate()
        .unwrap_err();
        assert_eq!(err.field, "required_field_threshold");
    }

    #[test]
    fn test_stats() {
        let learner = SchemaLearner::new();
        for i in 0..10 {
            learner.learn_from_request("/api/users", &json!({"id": i, "name": "test"}));
        }
        for i in 0..5 {
            learner.learn_from_request("/api/orders", &json!({"order_id": i}));
        }
        let stats = learner.get_stats();
        assert_eq!(stats.total_schemas, 2);
        assert_eq!(stats.total_samples, 15);
        assert!(stats.avg_fields_per_endpoint > 0.0);
    }

    #[test]
    fn test_export_import() {
        let learner = SchemaLearner::new();
        learner.learn_from_request("/api/users", &json!({"id": 1, "name": "test"}));
        learner.learn_from_request("/api/orders", &json!({"order_id": 100}));
        let exported = learner.export();
        assert_eq!(exported.len(), 2);
        let learner2 = SchemaLearner::new();
        learner2.import(exported);
        assert_eq!(learner2.len(), 2);
        assert!(learner2.get_schema("/api/users").is_some());
        assert!(learner2.get_schema("/api/orders").is_some());
    }

    #[test]
    fn test_nullable_fields() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            min_samples_for_validation: 5,
            ..Default::default()
        });
        for i in 0..10 {
            let body = if i % 2 == 0 {
                json!({"name": "test"})
            } else {
                json!({"name": null})
            };
            learner.learn_from_request("/api/users", &body);
        }
        let schema = learner.get_schema("/api/users").unwrap();
        let name_schema = schema.request_schema.get("name").unwrap();
        assert!(name_schema.nullable);
        let result = learner.validate_request("/api/users", &json!({"name": null}));
        assert!(!result
            .violations
            .iter()
            .any(|v| v.violation_type == ViolationType::TypeMismatch && v.field == "name"));
    }

    #[test]
    fn test_array_item_types() {
        let learner = SchemaLearner::new();
        let body = json!({
            "tags": ["tag1", "tag2"],
            "numbers": [1, 2, 3]
        });
        learner.learn_from_request("/api/items", &body);
        let schema = learner.get_schema("/api/items").unwrap();
        let tags_schema = schema.request_schema.get("tags").unwrap();
        let numbers_schema = schema.request_schema.get("numbers").unwrap();
        assert!(tags_schema
            .array_item_types
            .as_ref()
            .unwrap()
            .contains(&FieldType::String));
        assert!(numbers_schema
            .array_item_types
            .as_ref()
            .unwrap()
            .contains(&FieldType::Number));
    }

    #[test]
    fn test_validate_missing_required_field() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            min_samples_for_validation: 5,
            required_field_threshold: 0.9,
            ..Default::default()
        });
        for i in 0..10 {
            learner.learn_from_request("/api/users", &json!({"id": i, "name": "test"}));
        }
        let result = learner.validate_request("/api/users", &json!({"id": 1}));
        assert!(!result.is_valid());
        assert!(result
            .violations
            .iter()
            .any(|v| v.violation_type == ViolationType::MissingField && v.field == "name"));
    }

    #[test]
    fn test_validate_number_constraints() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            min_samples_for_validation: 5,
            number_value_tolerance: 2.0,
            ..Default::default()
        });
        // Learned range is 10..=100; with 2.0 tolerance the accepted range is 5..=200.
        for i in 0..10 {
            learner.learn_from_request("/api/items", &json!({"price": 10 + i * 10}));
        }
        let result = learner.validate_request("/api/items", &json!({"price": 500}));
        assert!(!result.is_valid());
        assert!(result
            .violations
            .iter()
            .any(|v| v.violation_type == ViolationType::NumberTooLarge));
        let result = learner.validate_request("/api/items", &json!({"price": 1}));
        assert!(!result.is_valid());
        assert!(result
            .violations
            .iter()
            .any(|v| v.violation_type == ViolationType::NumberTooSmall));
    }

    #[test]
    fn test_validate_deeply_nested_json_does_not_stack_overflow() {
        let learner = SchemaLearner::with_config(SchemaLearnerConfig {
            max_nesting_depth: 10,
            min_samples_for_validation: 0,
            ..Default::default()
        });
        let mut body = json!({"leaf": true});
        for i in 0..100 {
            body = json!({ format!("nest_{}", i): body });
        }
        learner.learn_from_request("/api/nested", &body);
        let result = learner.validate_request("/api/nested", &body);
        assert!(result.is_valid());
    }

    #[test]
    fn test_learn_array_root_body_is_silently_skipped() {
        let learner = SchemaLearner::new();
        let body = json!([{"id": 1}, {"id": 2}]);
        learner.learn_from_request("/api/arrays", &body);
        assert_eq!(learner.len(), 0);
    }

    #[test]
    fn test_learn_from_response_does_not_increment_sample_count() {
        let learner = SchemaLearner::new();
        learner.learn_from_response("/api/test", &json!({"ok": true}));
        let schema = learner.get_schema("/api/test").unwrap();
        assert_eq!(schema.sample_count, 0);
        assert!(schema.response_schema.contains_key("ok"));
        learner.learn_from_request("/api/test", &json!({"id": 1}));
        let schema = learner.get_schema("/api/test").unwrap();
        assert_eq!(schema.sample_count, 1);
    }

    #[test]
    fn test_learn_from_pair_both_none() {
        let learner = SchemaLearner::new();
        learner.learn_from_pair("/api/empty", None, None);
        let schema = learner.get_schema("/api/empty").unwrap();
        assert_eq!(schema.sample_count, 1);
        assert!(schema.request_schema.is_empty());
        assert!(schema.response_schema.is_empty());
    }
}