use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::Array1;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// A user-supplied natural-language request to be translated into a
/// structured query by [`NaturalLanguageQueryGenerator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NaturalLanguageQuery {
    /// The raw natural-language text, e.g. "find all users".
    pub text: String,
    /// Optional domain context (e.g. "e-commerce"); carried but not
    /// consumed by the visible generation pipeline.
    pub context: Option<String>,
    /// Minimum intent-classification confidence required for generation
    /// to proceed (defaults to 0.7 via `new`).
    pub confidence_threshold: f32,
}
impl NaturalLanguageQuery {
pub fn new(text: String) -> Self {
Self {
text,
context: None,
confidence_threshold: 0.7,
}
}
pub fn with_context(mut self, context: String) -> Self {
self.context = Some(context);
self
}
pub fn with_confidence_threshold(mut self, threshold: f32) -> Self {
self.confidence_threshold = threshold;
self
}
}
/// The result of translating a [`NaturalLanguageQuery`]: the chosen query
/// plus alternatives and classification metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratedQuery {
    /// The primary generated query text (template with placeholders filled).
    pub query: String,
    /// Confidence of the intent classification that produced this query.
    pub confidence: f32,
    /// Lower-confidence alternative formulations.
    pub alternatives: Vec<AlternativeQuery>,
    /// Intent, entities and related metadata from the generation pipeline.
    pub metadata: QueryMetadata,
}
/// An alternative query formulation offered alongside the primary result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeQuery {
    /// The alternative query text.
    pub query: String,
    /// Confidence assigned to this alternative (lower than the primary's).
    pub confidence: f32,
    /// Human-readable note on how this alternative differs.
    pub explanation: String,
}
/// Metadata describing how a query was derived from natural language.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryMetadata {
    /// Classified intent name (e.g. "search"); "unknown" by default.
    pub intent: String,
    /// Extracted entities: field name or key -> owning type or value.
    pub entities: HashMap<String, String>,
    /// Operations implied by the query (currently just the intent).
    pub operations: Vec<String>,
    /// Field names suggested for selection (the extracted entity keys).
    pub suggested_fields: Vec<String>,
}
impl Default for QueryMetadata {
fn default() -> Self {
Self {
intent: "unknown".to_string(),
entities: HashMap::new(),
operations: Vec::new(),
suggested_fields: Vec::new(),
}
}
}
/// Schema description used to drive entity extraction: known type names,
/// their fields, and optional per-field descriptions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaInfo {
    /// Registered type names, in registration order.
    pub types: Vec<String>,
    /// Type name -> list of that type's field names.
    pub fields: HashMap<String, Vec<String>>,
    /// Field name -> human-readable description.
    pub descriptions: HashMap<String, String>,
}
impl SchemaInfo {
    /// Creates an empty schema with no types, fields or descriptions.
    pub fn new() -> Self {
        Self {
            types: Vec::new(),
            fields: HashMap::new(),
            descriptions: HashMap::new(),
        }
    }

    /// Registers a type and its field list.
    ///
    /// Bug fix: registering the same type twice previously pushed a duplicate
    /// entry into `types` while `fields` silently overwrote, leaving the two
    /// collections inconsistent. The name is now pushed only once; the field
    /// list is still replaced on re-registration.
    pub fn add_type(&mut self, type_name: String, fields: Vec<String>) {
        if !self.types.contains(&type_name) {
            self.types.push(type_name.clone());
        }
        self.fields.insert(type_name, fields);
    }

    /// Attaches (or replaces) a human-readable description for a field.
    pub fn add_description(&mut self, field: String, description: String) {
        self.descriptions.insert(field, description);
    }
}
impl Default for SchemaInfo {
    /// Equivalent to [`SchemaInfo::new`]: an empty schema.
    fn default() -> Self {
        Self::new()
    }
}
/// Translates natural-language text into structured queries by combining a
/// registered schema, keyword intent classification, entity extraction and
/// intent-matched templates. All state lives behind `Arc<RwLock<..>>` so a
/// generator can be shared across async tasks.
pub struct NaturalLanguageQueryGenerator {
    /// Registered type/field schema consulted during entity extraction.
    schema: Arc<RwLock<SchemaInfo>>,
    /// Classifies query text into an intent name plus confidence.
    intent_classifier: Arc<RwLock<IntentClassifier>>,
    /// Extracts entities from query text against the schema.
    entity_extractor: Arc<RwLock<EntityExtractor>>,
    /// Templates searched (in insertion order) for one matching the intent.
    templates: Arc<RwLock<Vec<QueryTemplate>>>,
}
/// Keyword-based intent classifier.
#[derive(Debug, Clone)]
pub struct IntentClassifier {
    /// Intent name -> placeholder embedding. Only the *keys* drive
    /// `classify`; the embedding values are seeded by `initialize_intents`
    /// but not read anywhere in this file.
    intents: HashMap<String, Array1<f32>>,
}
impl IntentClassifier {
    /// Creates a classifier pre-populated with the built-in intent set.
    pub fn new() -> Self {
        let mut classifier = Self {
            intents: HashMap::new(),
        };
        classifier.initialize_intents();
        classifier
    }

    /// Seeds one placeholder embedding per built-in intent.
    ///
    /// NOTE(review): the embedding values depend only on the dimension index,
    /// so every intent currently gets an identical vector. `classify` does not
    /// read these embeddings; they are placeholders for a learned model.
    fn initialize_intents(&mut self) {
        let embedding_dim = 128;
        let intents = vec![
            "search",
            "filter",
            "aggregate",
            "count",
            "list",
            "get",
            "find",
            "sort",
            "group",
            "update",
            "delete",
            "create",
        ];
        for intent in intents {
            let embedding = Array1::from_vec(
                (0..embedding_dim)
                    .map(|i| ((i as f32 * 0.1) % 2.0) - 1.0)
                    .collect(),
            );
            self.intents.insert(intent.to_string(), embedding);
        }
    }

    /// Classifies `text` by keyword match: an intent name appearing as a
    /// substring of the lowercased text scores 0.9, otherwise 0.3. If nothing
    /// beats the 0.5 floor, `("search", 0.5)` is returned.
    ///
    /// Bug fix: the intents were previously scanned in `HashMap` iteration
    /// order, which is randomized, so text matching several intent keywords
    /// (e.g. "filter and sort") classified non-deterministically. Keys are now
    /// scanned in sorted order; ties resolve to the alphabetically first
    /// matching intent, deterministically.
    pub fn classify(&self, text: &str) -> (String, f32) {
        let text_lower = text.to_lowercase();
        let mut best_intent = "search".to_string();
        let mut best_score = 0.5;
        let mut candidates: Vec<&String> = self.intents.keys().collect();
        candidates.sort();
        for intent in candidates {
            let score = if text_lower.contains(intent.as_str()) {
                0.9
            } else {
                0.3
            };
            if score > best_score {
                best_score = score;
                best_intent = intent.clone();
            }
        }
        (best_intent, best_score)
    }
}
impl Default for IntentClassifier {
    /// Equivalent to [`IntentClassifier::new`]: built-in intents loaded.
    fn default() -> Self {
        Self::new()
    }
}
/// Extracts entities from natural-language text using the registered schema
/// plus a simple "key is/equals value" word-pattern scan.
#[derive(Debug, Clone)]
pub struct EntityExtractor {
    // Reserved for custom extraction patterns; never populated or read in
    // this file, hence the dead_code allowance.
    #[allow(dead_code)]
    patterns: HashMap<String, Vec<String>>,
}
impl EntityExtractor {
pub fn new() -> Self {
Self {
patterns: HashMap::new(),
}
}
pub fn extract(&self, text: &str, schema: &SchemaInfo) -> HashMap<String, String> {
let mut entities = HashMap::new();
let text_lower = text.to_lowercase();
for type_name in &schema.types {
if let Some(fields) = schema.fields.get(type_name) {
for field in fields {
if text_lower.contains(&field.to_lowercase()) {
entities.insert(field.clone(), type_name.clone());
}
}
}
}
let words: Vec<&str> = text.split_whitespace().collect();
for i in 0..words.len().saturating_sub(2) {
if words[i + 1] == "is" || words[i + 1] == "equals" {
entities.insert(
words[i].to_string(),
words[i + 2].trim_matches('"').to_string(),
);
}
}
entities
}
}
impl Default for EntityExtractor {
    /// Equivalent to [`EntityExtractor::new`]: no custom patterns.
    fn default() -> Self {
        Self::new()
    }
}
/// A query template matched by intent and filled with extracted entities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryTemplate {
    /// Identifier for this template (used in diagnostics).
    pub name: String,
    /// Intent name this template serves; matched exactly against the
    /// classifier's output.
    pub intent: String,
    /// Query text containing `{entity}` placeholders to be substituted.
    pub template: String,
    /// Entity keys that must be present for this template to be usable.
    pub required_entities: Vec<String>,
}
impl NaturalLanguageQueryGenerator {
    /// Creates a generator with an empty schema, the built-in intent
    /// classifier, a pattern-free entity extractor, and no templates.
    pub fn new() -> Self {
        Self {
            schema: Arc::new(RwLock::new(SchemaInfo::new())),
            intent_classifier: Arc::new(RwLock::new(IntentClassifier::new())),
            entity_extractor: Arc::new(RwLock::new(EntityExtractor::new())),
            templates: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Replaces the registered schema wholesale.
    pub async fn register_schema(&self, schema: SchemaInfo) -> Result<()> {
        let mut schema_guard = self.schema.write().await;
        *schema_guard = schema;
        Ok(())
    }

    /// Appends a query template; templates are searched by intent in
    /// insertion order during generation.
    pub async fn add_template(&self, template: QueryTemplate) -> Result<()> {
        let mut templates = self.templates.write().await;
        templates.push(template);
        Ok(())
    }

    /// Translates a natural-language query into a structured query.
    ///
    /// Pipeline: classify the intent, reject if confidence falls below the
    /// query's threshold, extract entities against the registered schema,
    /// fill the first template whose intent matches, then assemble
    /// alternatives and metadata.
    ///
    /// # Errors
    /// Fails when intent confidence is below `confidence_threshold`, when no
    /// template exists for the classified intent, or when the matched
    /// template is missing a required entity.
    pub async fn generate(&self, nl_query: NaturalLanguageQuery) -> Result<GeneratedQuery> {
        let intent_classifier = self.intent_classifier.read().await;
        let (intent, intent_confidence) = intent_classifier.classify(&nl_query.text);
        if intent_confidence < nl_query.confidence_threshold {
            return Err(anyhow!(
                "Low confidence in intent classification: {}",
                intent_confidence
            ));
        }
        let entity_extractor = self.entity_extractor.read().await;
        let schema = self.schema.read().await;
        let entities = entity_extractor.extract(&nl_query.text, &schema);
        let templates = self.templates.read().await;
        let matching_template = templates
            .iter()
            .find(|t| t.intent == intent)
            .ok_or_else(|| anyhow!("No template found for intent: {}", intent))?;
        let query = self.fill_template(matching_template, &entities).await?;
        let alternatives = self.generate_alternatives(&intent, &entities).await?;
        let suggested_fields: Vec<String> = entities.keys().cloned().collect();
        let metadata = QueryMetadata {
            intent: intent.clone(),
            entities,
            operations: vec![intent.clone()],
            suggested_fields,
        };
        Ok(GeneratedQuery {
            query,
            confidence: intent_confidence,
            alternatives,
            metadata,
        })
    }

    /// Substitutes each `{key}` placeholder in the template with its entity
    /// value.
    ///
    /// Bug fix: `QueryTemplate::required_entities` was declared but never
    /// enforced, so a template could be emitted with dangling placeholders.
    /// A missing required entity now fails fast with a descriptive error.
    async fn fill_template(
        &self,
        template: &QueryTemplate,
        entities: &HashMap<String, String>,
    ) -> Result<String> {
        if let Some(missing) = template
            .required_entities
            .iter()
            .find(|required| !entities.contains_key(*required))
        {
            return Err(anyhow!(
                "Template '{}' requires entity '{}' which was not found in the query",
                template.name,
                missing
            ));
        }
        let mut query = template.template.clone();
        for (key, value) in entities {
            let placeholder = format!("{{{}}}", key);
            query = query.replace(&placeholder, value);
        }
        Ok(query)
    }

    /// Builds lower-confidence alternative formulations; currently a single
    /// filter-based variant, produced only when entities were extracted.
    async fn generate_alternatives(
        &self,
        intent: &str,
        entities: &HashMap<String, String>,
    ) -> Result<Vec<AlternativeQuery>> {
        let mut alternatives = Vec::new();
        if !entities.is_empty() {
            let alt_query = format!(
                "query {{ {}(filter: {}) {{ id }} }}",
                intent,
                self.format_filter(entities)
            );
            alternatives.push(AlternativeQuery {
                query: alt_query,
                confidence: 0.6,
                explanation: "Alternative with minimal field selection".to_string(),
            });
        }
        Ok(alternatives)
    }

    /// Renders entities as a brace-wrapped filter object, e.g.
    /// `{ name: "John" }`.
    /// NOTE(review): entry order follows `HashMap` iteration order and is
    /// therefore unstable across runs — confirm callers don't compare the
    /// full string.
    fn format_filter(&self, entities: &HashMap<String, String>) -> String {
        let filters: Vec<String> = entities
            .iter()
            .map(|(k, v)| format!("{}: \"{}\"", k, v))
            .collect();
        format!("{{ {} }}", filters.join(", "))
    }

    /// Returns a clone of the currently registered schema.
    pub async fn get_schema(&self) -> SchemaInfo {
        let schema = self.schema.read().await;
        schema.clone()
    }

    /// Placeholder for future training: accepts (text, intent) examples and
    /// currently ignores them, always succeeding.
    pub async fn train_intent_classifier(&self, _examples: Vec<(String, String)>) -> Result<()> {
        Ok(())
    }
}
impl Default for NaturalLanguageQueryGenerator {
    /// Equivalent to [`NaturalLanguageQueryGenerator::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests covering query construction, schema registration, intent
    //! classification, entity extraction and end-to-end query generation.
    use super::*;

    // `new` stores the text and applies the 0.7 default threshold.
    #[test]
    fn test_natural_language_query_creation() {
        let query = NaturalLanguageQuery::new("find all users".to_string());
        assert_eq!(query.text, "find all users");
        assert_eq!(query.confidence_threshold, 0.7);
    }

    // The builder attaches context without disturbing other fields.
    #[test]
    fn test_natural_language_query_with_context() {
        let query = NaturalLanguageQuery::new("get items".to_string())
            .with_context("e-commerce".to_string());
        assert_eq!(query.context, Some("e-commerce".to_string()));
    }

    // add_type records the type name and its field list.
    #[test]
    fn test_schema_info_creation() {
        let mut schema = SchemaInfo::new();
        schema.add_type(
            "User".to_string(),
            vec!["id".to_string(), "name".to_string()],
        );
        assert_eq!(schema.types.len(), 1);
        assert_eq!(schema.fields.get("User").expect("should succeed").len(), 2);
    }

    // Text containing the "search" keyword classifies above the 0.5 floor.
    #[test]
    fn test_intent_classifier() {
        let classifier = IntentClassifier::new();
        let (intent, confidence) = classifier.classify("search for users");
        assert_eq!(intent, "search");
        assert!(confidence > 0.5);
    }

    // A non-default intent keyword ("filter") beats the "search" fallback.
    #[test]
    fn test_intent_classifier_filter() {
        let classifier = IntentClassifier::new();
        let (intent, _) = classifier.classify("filter by name");
        assert_eq!(intent, "filter");
    }

    // Schema field names mentioned in the text are extracted as entities.
    #[test]
    fn test_entity_extractor() {
        let extractor = EntityExtractor::new();
        let mut schema = SchemaInfo::new();
        schema.add_type(
            "User".to_string(),
            vec!["name".to_string(), "email".to_string()],
        );
        let entities = extractor.extract("get user name", &schema);
        assert!(entities.contains_key("name"));
    }

    // A fresh generator starts with an empty schema.
    #[tokio::test]
    async fn test_generator_creation() {
        let generator = NaturalLanguageQueryGenerator::new();
        let schema = generator.get_schema().await;
        assert_eq!(schema.types.len(), 0);
    }

    // register_schema replaces the stored schema wholesale.
    #[tokio::test]
    async fn test_register_schema() {
        let generator = NaturalLanguageQueryGenerator::new();
        let mut schema = SchemaInfo::new();
        schema.add_type("User".to_string(), vec!["id".to_string()]);
        generator
            .register_schema(schema)
            .await
            .expect("should succeed");
        let registered = generator.get_schema().await;
        assert_eq!(registered.types.len(), 1);
    }

    // Adding a template succeeds on a fresh generator.
    #[tokio::test]
    async fn test_add_template() {
        let generator = NaturalLanguageQueryGenerator::new();
        let template = QueryTemplate {
            name: "search_users".to_string(),
            intent: "search".to_string(),
            template: "query { users { id name } }".to_string(),
            required_entities: vec![],
        };
        generator
            .add_template(template)
            .await
            .expect("should succeed");
    }

    // End-to-end: schema + template + "search" text produce a search query.
    #[tokio::test]
    async fn test_generate_query() {
        let generator = NaturalLanguageQueryGenerator::new();
        let mut schema = SchemaInfo::new();
        schema.add_type(
            "User".to_string(),
            vec!["id".to_string(), "name".to_string()],
        );
        generator
            .register_schema(schema)
            .await
            .expect("should succeed");
        let template = QueryTemplate {
            name: "search_users".to_string(),
            intent: "search".to_string(),
            template: "query { users { id name } }".to_string(),
            required_entities: vec![],
        };
        generator
            .add_template(template)
            .await
            .expect("should succeed");
        let nl_query = NaturalLanguageQuery::new("search for users".to_string());
        let result = generator.generate(nl_query).await;
        assert!(result.is_ok());
        let generated = result.expect("should succeed");
        assert!(generated.query.contains("users"));
        assert_eq!(generated.metadata.intent, "search");
    }

    // Text matching no intent keyword scores 0.5, below the 0.95 threshold.
    #[tokio::test]
    async fn test_generate_query_low_confidence() {
        let generator = NaturalLanguageQueryGenerator::new();
        let nl_query =
            NaturalLanguageQuery::new("xyzabc".to_string()).with_confidence_threshold(0.95);
        let result = generator.generate(nl_query).await;
        assert!(result.is_err());
    }

    // format_filter renders each entity as `key: "value"`; only containment
    // is asserted because HashMap iteration order is unstable.
    #[tokio::test]
    async fn test_format_filter() {
        let generator = NaturalLanguageQueryGenerator::new();
        let mut entities = HashMap::new();
        entities.insert("name".to_string(), "John".to_string());
        entities.insert("age".to_string(), "30".to_string());
        let filter = generator.format_filter(&entities);
        assert!(filter.contains("name"));
        assert!(filter.contains("John"));
    }

    // Training is currently a no-op that always returns Ok.
    #[tokio::test]
    async fn test_train_intent_classifier() {
        let generator = NaturalLanguageQueryGenerator::new();
        let examples = vec![
            ("find all users".to_string(), "search".to_string()),
            ("filter by name".to_string(), "filter".to_string()),
        ];
        let result = generator.train_intent_classifier(examples).await;
        assert!(result.is_ok());
    }

    // Default metadata carries the "unknown" sentinel and empty collections.
    #[test]
    fn test_query_metadata_default() {
        let metadata = QueryMetadata::default();
        assert_eq!(metadata.intent, "unknown");
        assert!(metadata.entities.is_empty());
    }

    // AlternativeQuery is a plain data carrier.
    #[test]
    fn test_alternative_query() {
        let alt = AlternativeQuery {
            query: "query { users { id } }".to_string(),
            confidence: 0.7,
            explanation: "test".to_string(),
        };
        assert_eq!(alt.confidence, 0.7);
    }
}