#[cfg(feature = "schema")]
use crate::Confidence;
#[cfg(feature = "schema")]
use schemars::{schema_for, JsonSchema};
#[cfg(feature = "schema")]
use serde_json::Value;
#[cfg(feature = "schema")]
pub use schemars;
/// Builds the JSON Schema for `SchemaEntity` and returns it as a JSON value.
///
/// # Panics
///
/// Panics if the generated schema cannot be converted into a
/// `serde_json::Value`.
#[cfg(feature = "schema")]
pub fn generate_entity_schema() -> Value {
    let root = schema_for!(schema_types::SchemaEntity);
    serde_json::to_value(root).expect("Entity schema should be valid JSON")
}
/// Builds the JSON Schema for `SchemaEntityType` and returns it as a JSON value.
///
/// # Panics
///
/// Panics if the generated schema cannot be converted into a
/// `serde_json::Value`.
#[cfg(feature = "schema")]
pub fn generate_entity_type_schema() -> Value {
    let root = schema_for!(schema_types::SchemaEntityType);
    serde_json::to_value(root).expect("EntityType schema should be valid JSON")
}
/// Builds the JSON Schema for `SchemaGroundedDocument` and returns it as a
/// JSON value.
///
/// # Panics
///
/// Panics if the generated schema cannot be converted into a
/// `serde_json::Value`.
#[cfg(feature = "schema")]
pub fn generate_grounded_document_schema() -> Value {
    let root = schema_for!(schema_types::SchemaGroundedDocument);
    serde_json::to_value(root).expect("GroundedDocument schema should be valid JSON")
}
/// Builds the JSON Schema for `SchemaSpan` and returns it as a JSON value.
///
/// # Panics
///
/// Panics if the generated schema cannot be converted into a
/// `serde_json::Value`.
#[cfg(feature = "schema")]
pub fn generate_span_schema() -> Value {
    let root = schema_for!(schema_types::SchemaSpan);
    serde_json::to_value(root).expect("Span schema should be valid JSON")
}
/// Builds the JSON Schema for `SchemaRelation` and returns it as a JSON value.
///
/// # Panics
///
/// Panics if the generated schema cannot be converted into a
/// `serde_json::Value`.
#[cfg(feature = "schema")]
pub fn generate_relation_schema() -> Value {
    let root = schema_for!(schema_types::SchemaRelation);
    serde_json::to_value(root).expect("Relation schema should be valid JSON")
}
/// Generates every schema exposed by this module, keyed by a short name.
///
/// The returned map contains the keys `"entity"`, `"entity_type"`, `"span"`,
/// `"relation"` and `"grounded_document"`, each mapped to the output of the
/// corresponding `generate_*_schema` function.
///
/// # Panics
///
/// Panics if any of the individual generators panics.
#[cfg(feature = "schema")]
pub fn generate_all_schemas() -> std::collections::HashMap<&'static str, Value> {
    // Name -> generator table; entries are evaluated in this order.
    let generators: [(&'static str, fn() -> Value); 5] = [
        ("entity", generate_entity_schema),
        ("entity_type", generate_entity_type_schema),
        ("span", generate_span_schema),
        ("relation", generate_relation_schema),
        ("grounded_document", generate_grounded_document_schema),
    ];

    generators
        .iter()
        .map(|&(name, generate)| (name, generate()))
        .collect()
}
/// Types whose derived `JsonSchema` implementations back the
/// `generate_*_schema` functions of the parent module.
#[cfg(feature = "schema")]
mod schema_types {
    // NOTE: the notes on types, variants and fields below are deliberately
    // plain `//` comments. schemars copies `///` doc comments into the
    // generated schema as titles/descriptions, which would change the output
    // of the `generate_*_schema` functions.
    use super::*;
    use serde::{Deserialize, Serialize};

    // Coarse category attached to `SchemaEntityType::Custom`.
    // Variant names are serialized in lowercase (e.g. `"agent"`).
    #[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)]
    #[serde(rename_all = "lowercase")]
    pub enum SchemaEntityCategory {
        Agent,
        Organization,
        Place,
        Creative,
        Temporal,
        Numeric,
        Contact,
        Relation,
        Misc,
    }

    // Type label of an entity.
    //
    // Adjacently tagged: the variant name goes under `"type"` and any payload
    // under `"value"`, e.g. `{"type":"Person"}` or
    // `{"type":"Other","value":"Gene"}`.
    //
    // A purely internal tag (`tag = "type"` alone) cannot represent
    // `Other(String)`: serde only supports newtype variants wrapping
    // structs/maps in that representation and fails at runtime when asked to
    // serialize a tagged string, so the payload needs its own `content` key.
    #[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)]
    #[serde(tag = "type", content = "value")]
    pub enum SchemaEntityType {
        Person,
        Organization,
        Location,
        Date,
        Time,
        Money,
        Percent,
        Quantity,
        Cardinal,
        Ordinal,
        Email,
        Url,
        Phone,
        // User-defined type with an explicit name and category.
        Custom {
            name: String,
            category: SchemaEntityCategory,
        },
        // Free-form type label carried as a plain string.
        Other(String),
    }

    // A span described by `start`/`end` plus optional `width`/`height`.
    //
    // NOTE(review): `start`/`end` are signed here while `SchemaEntity` uses
    // `usize` for its offsets, so this schema admits negative values —
    // confirm whether that is intended. The units of all four fields are not
    // visible from this file.
    #[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)]
    pub struct SchemaSpan {
        pub start: i64,
        pub end: i64,
        // Omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub width: Option<f64>,
        // Omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub height: Option<f64>,
    }

    // A single entity: its text, type, `start`/`end` positions and confidence.
    // All `Option` fields are omitted from serialized output when `None`.
    #[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)]
    pub struct SchemaEntity {
        pub text: String,
        pub entity_type: SchemaEntityType,
        // NOTE(review): presumably offsets into the source text; the unit
        // (bytes vs. characters) is not visible here — confirm.
        pub start: usize,
        pub end: usize,
        pub confidence: Confidence,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub normalized: Option<String>,
        // NOTE(review): presumably an external knowledge-base identifier — confirm.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub kb_id: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub canonical_id: Option<u64>,
    }

    // A typed relation between two entities referenced by index.
    #[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)]
    pub struct SchemaRelation {
        // NOTE(review): presumably indices into the owning document's
        // `entities` list — confirm against the producer.
        pub source_idx: usize,
        pub target_idx: usize,
        pub relation_type: String,
        pub confidence: Confidence,
        // Omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub trigger: Option<SchemaSpan>,
    }

    // A document together with its entities and relations.
    #[derive(Debug, Clone, JsonSchema, Serialize, Deserialize)]
    pub struct SchemaGroundedDocument {
        pub id: String,
        pub text: String,
        pub entities: Vec<SchemaEntity>,
        // Defaults to empty when missing on input; omitted on output when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        pub relations: Vec<SchemaRelation>,
        // Omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub language: Option<String>,
        // Omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub source: Option<String>,
    }
}
#[cfg(feature = "schema")]
pub use schema_types::*;
#[cfg(test)]
#[cfg(feature = "schema")]
mod tests {
    use super::*;

    /// The entity schema must be a JSON object carrying at least one of the
    /// usual top-level schema keys.
    #[test]
    fn test_entity_schema_generation() {
        let schema = generate_entity_schema();
        let obj = schema
            .as_object()
            .expect("entity schema should be a JSON object");
        assert!(
            obj.contains_key("$schema") || obj.contains_key("title") || obj.contains_key("type")
        );
    }

    /// `generate_all_schemas` must expose exactly the five registered schemas.
    #[test]
    fn test_all_schemas_generation() {
        let schemas = generate_all_schemas();
        let expected = [
            "entity",
            "entity_type",
            "span",
            "relation",
            "grounded_document",
        ];
        for name in expected {
            assert!(schemas.contains_key(name), "missing schema: {}", name);
        }
        // Guards against a key being silently dropped or duplicated.
        assert_eq!(schemas.len(), expected.len());
    }
}