use crate::{
EmbeddingError, EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector,
};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use chrono::Utc;
use scirs2_core::ndarray_ext::{Array1, Array2};
#[allow(unused_imports)]
use scirs2_core::random::{Random, RngExt};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use uuid::Uuid;
/// Supported graph-neural-network architectures.
///
/// Only GCN, GraphSAGE, GAT, and GIN have dedicated propagation rules in
/// this file; the remaining variants fall back to GCN-style propagation.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum GNNType {
    /// Graph Convolutional Network.
    GCN,
    /// Sample-and-aggregate GNN (GraphSAGE).
    GraphSAGE,
    /// Graph Attention Network.
    GAT,
    /// Transformer-style attention over the graph.
    GraphTransformer,
    /// Graph Isomorphism Network.
    GIN,
    /// Principal Neighbourhood Aggregation.
    PNA,
    /// Heterogeneous-graph GNN.
    HetGNN,
    /// Temporal Graph Network.
    TGN,
}
impl GNNType {
    /// Recommended number of message-passing layers for each architecture:
    /// shallow for convolution-style models, deeper for transformer and
    /// isomorphism variants.
    pub fn default_layers(&self) -> usize {
        match self {
            GNNType::GraphTransformer => 4,
            GNNType::GIN | GNNType::PNA => 3,
            GNNType::GCN
            | GNNType::GraphSAGE
            | GNNType::GAT
            | GNNType::HetGNN
            | GNNType::TGN => 2,
        }
    }
    /// Whether this architecture needs per-layer attention weight matrices.
    pub fn requires_attention(&self) -> bool {
        matches!(self, GNNType::GAT | GNNType::GraphTransformer)
    }
}
/// Neighbourhood aggregation strategies used during message passing.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum AggregationType {
    /// Element-wise mean of neighbour features.
    Mean,
    /// Element-wise maximum of neighbour features.
    Max,
    /// Element-wise sum of neighbour features.
    Sum,
    /// Sequential aggregation; implemented here as an exponential moving
    /// average placeholder rather than a true LSTM.
    LSTM,
}
/// Configuration for a GNN-based knowledge-graph embedding model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GNNConfig {
    /// Shared embedding-model settings (dimensions, max epochs, ...).
    pub base_config: ModelConfig,
    /// Which GNN architecture to use.
    pub gnn_type: GNNType,
    /// Number of message-passing layers to build.
    pub num_layers: usize,
    /// Output width of each intermediate layer; layers beyond this list
    /// (and always the final layer) use `base_config.dimensions`.
    pub hidden_dimensions: Vec<usize>,
    /// Dropout probability applied after every layer in the forward pass.
    pub dropout: f64,
    /// Neighbourhood aggregation strategy.
    pub aggregation: AggregationType,
    /// Attention head count for attention-based variants (defaults to 8).
    pub num_heads: Option<usize>,
    /// Neighbour sampling budget.
    /// NOTE(review): not read anywhere in this file — confirm intent.
    pub sample_neighbors: Option<usize>,
    /// Add residual (skip) connections where dimensions allow.
    pub residual_connections: bool,
    /// Apply layer normalization after activation (GCN path only).
    pub layer_norm: bool,
    /// Incorporate edge features.
    /// NOTE(review): not read anywhere in this file — confirm intent.
    pub edge_features: bool,
}
impl Default for GNNConfig {
    /// Two-layer GCN with mean aggregation, 10% dropout, residual
    /// connections, and layer normalization enabled.
    fn default() -> Self {
        Self {
            base_config: ModelConfig::default(),
            gnn_type: GNNType::GCN,
            num_layers: 2,
            // NOTE(review): with num_layers = 2 only hidden_dimensions[0]
            // is used — the final layer always projects back to
            // base_config.dimensions, so the 64 entry is never read.
            hidden_dimensions: vec![128, 64],
            dropout: 0.1,
            aggregation: AggregationType::Mean,
            num_heads: None,
            sample_neighbors: None,
            residual_connections: true,
            layer_norm: true,
            edge_features: false,
        }
    }
}
/// Knowledge-graph embedding model driven by a simplified graph neural
/// network over the triple graph.
pub struct GNNEmbedding {
    /// Unique model instance id.
    id: Uuid,
    /// Architecture and training configuration.
    config: GNNConfig,
    /// Final per-entity embeddings (populated by `train`).
    entity_embeddings: HashMap<String, Array1<f32>>,
    /// Per-relation embeddings (randomly initialized during `train`).
    relation_embeddings: HashMap<String, Array1<f32>>,
    // Dense-index interning tables for entities and relations, both
    // directions, maintained by `add_triple`.
    entity_to_idx: HashMap<String, usize>,
    relation_to_idx: HashMap<String, usize>,
    idx_to_entity: HashMap<usize, String>,
    idx_to_relation: HashMap<usize, String>,
    /// Outgoing edges: subject idx -> set of (object idx, relation idx).
    adjacency_list: HashMap<usize, HashSet<(usize, usize)>>,
    /// Incoming edges: object idx -> set of (subject idx, relation idx).
    reverse_adjacency_list: HashMap<usize, HashSet<(usize, usize)>>,
    /// All triples added so far.
    triples: Vec<Triple>,
    /// Message-passing layers (built by `initialize_layers`).
    layers: Vec<GNNLayer>,
    /// Set by `train`; cleared again whenever a new triple is added.
    is_trained: bool,
    creation_time: chrono::DateTime<Utc>,
    last_training_time: Option<chrono::DateTime<Utc>>,
}
/// Parameters of a single message-passing layer.
struct GNNLayer {
    /// Linear transform, shape (input_dim, output_dim).
    weight_matrix: Array2<f32>,
    /// Additive bias, length output_dim.
    bias: Array1<f32>,
    /// Present only for attention-based architectures (GAT / transformer).
    attention_weights: Option<AttentionWeights>,
    /// Present only when `GNNConfig::layer_norm` is enabled.
    layer_norm: Option<LayerNormalization>,
}
/// Multi-head attention projection matrices for one layer.
struct AttentionWeights {
    query_weights: Array2<f32>,
    key_weights: Array2<f32>,
    value_weights: Array2<f32>,
    num_heads: usize,
}
/// Learned affine parameters for layer normalization.
struct LayerNormalization {
    gamma: Array1<f32>,
    beta: Array1<f32>,
    /// Numerical-stability constant added to the variance.
    epsilon: f32,
}
impl GNNEmbedding {
/// Creates an empty, untrained model with a fresh random id.
pub fn new(config: GNNConfig) -> Self {
    Self {
        id: Uuid::new_v4(),
        config,
        entity_embeddings: HashMap::new(),
        relation_embeddings: HashMap::new(),
        entity_to_idx: HashMap::new(),
        relation_to_idx: HashMap::new(),
        idx_to_entity: HashMap::new(),
        idx_to_relation: HashMap::new(),
        adjacency_list: HashMap::new(),
        reverse_adjacency_list: HashMap::new(),
        triples: Vec::new(),
        layers: Vec::new(),
        is_trained: false,
        creation_time: Utc::now(),
        last_training_time: None,
    }
}
/// (Re)builds the layer stack with Xavier/Glorot-style uniform weight
/// initialization, using a fixed RNG seed for reproducibility.
///
/// Layer `i` maps `input_dim -> hidden_dimensions[i]` for intermediate
/// layers (falling back to `base_config.dimensions` when no hidden size
/// is configured); the final layer always projects back to
/// `base_config.dimensions`.
fn initialize_layers(&mut self) -> Result<()> {
    self.layers.clear();
    let mut rng = Random::seed(42);
    let mut input_dim = self.config.base_config.dimensions;
    let num_layers = self.config.num_layers;
    for i in 0..num_layers {
        let output_dim = if i == num_layers - 1 {
            self.config.base_config.dimensions
        } else if i < self.config.hidden_dimensions.len() {
            self.config.hidden_dimensions[i]
        } else {
            self.config.base_config.dimensions
        };
        // Xavier/Glorot uniform bound: sqrt(2 / (fan_in + fan_out)).
        let scale = (2.0 / (input_dim + output_dim) as f32).sqrt();
        // Samples a (rows, cols) matrix uniformly from [-scale, scale).
        let mut init_matrix = |rows: usize, cols: usize| {
            Array2::from_shape_fn((rows, cols), |_| {
                rng.random_range(0.0..1.0) * scale * 2.0 - scale
            })
        };
        let weight_matrix = init_matrix(input_dim, output_dim);
        let attention_weights = if self.config.gnn_type.requires_attention() {
            // Clamp to at least one head: a configured Some(0) would
            // otherwise panic with a divide-by-zero below.
            let num_heads = self.config.num_heads.unwrap_or(8).max(1);
            let head_dim = output_dim / num_heads;
            let attention_dim = head_dim * num_heads;
            Some(AttentionWeights {
                query_weights: init_matrix(input_dim, attention_dim),
                key_weights: init_matrix(input_dim, attention_dim),
                value_weights: init_matrix(input_dim, attention_dim),
                num_heads,
            })
        } else {
            None
        };
        let bias = Array1::zeros(output_dim);
        let layer_norm = if self.config.layer_norm {
            Some(LayerNormalization {
                gamma: Array1::ones(output_dim),
                beta: Array1::zeros(output_dim),
                epsilon: 1e-5,
            })
        } else {
            None
        };
        self.layers.push(GNNLayer {
            weight_matrix,
            bias,
            attention_weights,
            layer_norm,
        });
        // This layer's output width feeds the next layer's input.
        input_dim = output_dim;
    }
    Ok(())
}
/// Rebuilds both directed adjacency structures from the stored triples.
/// Every edge is recorded forward (subject -> object) and backward
/// (object -> subject), tagged with its relation index. Indexing into the
/// interning maps cannot panic because `add_triple` interns every IRI.
fn build_adjacency_lists(&mut self) {
    self.adjacency_list.clear();
    self.reverse_adjacency_list.clear();
    for triple in &self.triples {
        let s = self.entity_to_idx[&triple.subject.iri];
        let o = self.entity_to_idx[&triple.object.iri];
        let r = self.relation_to_idx[&triple.predicate.iri];
        self.adjacency_list.entry(s).or_default().insert((o, r));
        self.reverse_adjacency_list
            .entry(o)
            .or_default()
            .insert((s, r));
    }
}
/// Aggregates the feature vectors of a node's in- and out-neighbours
/// according to `self.config.aggregation`.
///
/// Neighbours are visited in ascending index order so the result is
/// deterministic: `HashSet` iteration order is randomized per process,
/// and both floating-point summation and the sequential LSTM-style
/// aggregator are order-sensitive, which previously made embeddings vary
/// between runs despite the seeded RNG.
///
/// Returns a zero vector (matching the feature width) for isolated
/// nodes. Panics if `node_features` is empty; callers only invoke this
/// for nodes that are themselves present in `node_features`.
fn aggregate_neighbors(
    &self,
    node_idx: usize,
    node_features: &HashMap<usize, Array1<f32>>,
) -> Array1<f32> {
    // Collect neighbour indices from both edge directions, then sort for
    // run-to-run determinism.
    let mut neighbor_indices: Vec<usize> = Vec::new();
    if let Some(neighbors) = self.adjacency_list.get(&node_idx) {
        neighbor_indices.extend(neighbors.iter().map(|(n, _)| *n));
    }
    if let Some(reverse_neighbors) = self.reverse_adjacency_list.get(&node_idx) {
        neighbor_indices.extend(reverse_neighbors.iter().map(|(n, _)| *n));
    }
    neighbor_indices.sort_unstable();
    let neighbor_features: Vec<Array1<f32>> = neighbor_indices
        .iter()
        .filter_map(|idx| node_features.get(idx).cloned())
        .collect();
    if neighbor_features.is_empty() {
        return Array1::zeros(
            node_features
                .values()
                .next()
                .expect("node_features should not be empty")
                .len(),
        );
    }
    match self.config.aggregation {
        AggregationType::Mean => {
            let sum: Array1<f32> = neighbor_features
                .iter()
                .fold(Array1::zeros(neighbor_features[0].len()), |acc, x| acc + x);
            sum / neighbor_features.len() as f32
        }
        AggregationType::Max => neighbor_features.iter().fold(
            Array1::from_elem(neighbor_features[0].len(), f32::NEG_INFINITY),
            // Element-wise max; mutating the accumulator in place avoids
            // the per-step clone the previous version paid.
            |mut acc, x| {
                for (i, &val) in x.iter().enumerate() {
                    acc[i] = acc[i].max(val);
                }
                acc
            },
        ),
        AggregationType::Sum => neighbor_features
            .iter()
            .fold(Array1::zeros(neighbor_features[0].len()), |acc, x| acc + x),
        AggregationType::LSTM => self.aggregate_neighbors_lstm(&neighbor_features),
    }
}
/// Sequential (order-sensitive) placeholder for an LSTM aggregator: an
/// exponential moving average over the neighbour features.
///
/// Panics on an empty slice; the caller guarantees at least one
/// neighbour before dispatching here.
fn aggregate_neighbors_lstm(&self, neighbor_features: &[Array1<f32>]) -> Array1<f32> {
    let dim = neighbor_features[0].len();
    neighbor_features
        .iter()
        .fold(Array1::<f32>::zeros(dim), |state, feature| {
            state * 0.8 + feature * 0.2
        })
}
/// Dispatches one message-passing step to the architecture-specific
/// propagation rule. Variants without a dedicated rule (GraphTransformer,
/// PNA, HetGNN, TGN) fall back to GCN-style propagation.
fn apply_layer(
    &self,
    layer: &GNNLayer,
    node_features: &HashMap<usize, Array1<f32>>,
) -> HashMap<usize, Array1<f32>> {
    let mut updated = HashMap::new();
    match self.config.gnn_type {
        GNNType::GraphSAGE => self.apply_graphsage_layer(layer, node_features, &mut updated),
        GNNType::GAT => self.apply_gat_layer(layer, node_features, &mut updated),
        GNNType::GIN => self.apply_gin_layer(layer, node_features, &mut updated),
        // GCN itself plus every architecture without a dedicated rule.
        _ => self.apply_gcn_layer(layer, node_features, &mut updated),
    }
    updated
}
/// GCN-style propagation: add the aggregated neighbourhood to each
/// node's own feature, apply the linear transform, ReLU, and (when
/// configured) layer normalization.
fn apply_gcn_layer(
    &self,
    layer: &GNNLayer,
    node_features: &HashMap<usize, Array1<f32>>,
    new_features: &mut HashMap<usize, Array1<f32>>,
) {
    for (&idx, feat) in node_features {
        let neighborhood = self.aggregate_neighbors(idx, node_features);
        let pre_activation = (feat + &neighborhood).dot(&layer.weight_matrix) + &layer.bias;
        // ReLU non-linearity.
        let relu = pre_activation.mapv(|v| v.max(0.0));
        let out = match &layer.layer_norm {
            Some(ln) => self.apply_layer_norm(&relu, ln),
            None => relu,
        };
        new_features.insert(idx, out);
    }
}
/// GraphSAGE-style propagation: transform the node and its aggregated
/// neighbourhood separately (sharing one weight matrix), sum, ReLU, then
/// L2-normalize the result.
fn apply_graphsage_layer(
    &self,
    layer: &GNNLayer,
    node_features: &HashMap<usize, Array1<f32>>,
    new_features: &mut HashMap<usize, Array1<f32>>,
) {
    for (&idx, feat) in node_features {
        let neighborhood = self.aggregate_neighbors(idx, node_features);
        let self_term = feat.dot(&layer.weight_matrix) + &layer.bias;
        let neighbor_term = neighborhood.dot(&layer.weight_matrix) + &layer.bias;
        let relu = (&self_term + &neighbor_term).mapv(|v| v.max(0.0));
        // L2 norm; the epsilon avoids dividing an all-zero row by zero.
        let norm = relu.dot(&relu).sqrt() + 1e-6;
        new_features.insert(idx, &relu / norm);
    }
}
/// GAT-style propagation using scaled dot-product attention over each
/// node's in- and out-neighbours.
///
/// Falls back to GCN-style propagation (self + aggregated neighbours,
/// then linear + ReLU) whenever attention is not applicable: feature /
/// projection dimension mismatches, or no neighbour yielding a valid
/// score. Isolated nodes get a plain linear + ReLU transform.
fn apply_gat_layer(
    &self,
    layer: &GNNLayer,
    node_features: &HashMap<usize, Array1<f32>>,
    new_features: &mut HashMap<usize, Array1<f32>>,
) {
    let attention = layer
        .attention_weights
        .as_ref()
        .expect("attention_weights should be initialized for GAT layer");
    for (node_idx, feature) in node_features {
        // Collect neighbour indices from both edge directions.
        let mut neighbor_indices = Vec::new();
        if let Some(neighbors) = self.adjacency_list.get(node_idx) {
            neighbor_indices.extend(neighbors.iter().map(|(n, _)| *n));
        }
        if let Some(neighbors) = self.reverse_adjacency_list.get(node_idx) {
            neighbor_indices.extend(neighbors.iter().map(|(n, _)| *n));
        }
        // Isolated node: nothing to attend to, just transform the node.
        if neighbor_indices.is_empty() {
            let transformed = feature.dot(&layer.weight_matrix) + &layer.bias;
            let activated = transformed.mapv(|x| x.max(0.0));
            new_features.insert(*node_idx, activated);
            continue;
        }
        // Feature width does not match the attention projections (layer
        // widths can differ across the stack) — GCN-style fallback.
        if feature.len() != attention.query_weights.shape()[0] {
            let aggregated = self.aggregate_neighbors(*node_idx, node_features);
            let combined = feature + &aggregated;
            let transformed = combined.dot(&layer.weight_matrix) + &layer.bias;
            let activated = transformed.mapv(|x| x.max(0.0));
            new_features.insert(*node_idx, activated);
            continue;
        }
        let query = feature.dot(&attention.query_weights);
        // Score every neighbour whose projections line up with the query.
        let mut attention_scores = Vec::new();
        let mut neighbor_values = Vec::new();
        for neighbor_idx in &neighbor_indices {
            if let Some(neighbor_feature) = node_features.get(neighbor_idx) {
                if neighbor_feature.len() != attention.key_weights.shape()[0] {
                    continue;
                }
                let key = neighbor_feature.dot(&attention.key_weights);
                let value = neighbor_feature.dot(&attention.value_weights);
                if query.len() == key.len() {
                    // NOTE(review): scaling divides by sqrt(num_heads)
                    // rather than the conventional sqrt(d_k) — confirm
                    // this is intentional.
                    let score = query.dot(&key) / (attention.num_heads as f32).sqrt();
                    attention_scores.push(score);
                    neighbor_values.push(value);
                }
            }
        }
        // No neighbour survived the dimension checks — GCN-style fallback.
        if attention_scores.is_empty() {
            let aggregated = self.aggregate_neighbors(*node_idx, node_features);
            let combined = feature + &aggregated;
            let transformed = combined.dot(&layer.weight_matrix) + &layer.bias;
            let activated = transformed.mapv(|x| x.max(0.0));
            new_features.insert(*node_idx, activated);
            continue;
        }
        // Numerically stable softmax: subtract the max before exp.
        let max_score = attention_scores
            .iter()
            .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let exp_scores: Vec<f32> = attention_scores
            .iter()
            .map(|&s| (s - max_score).exp())
            .collect();
        let sum_exp = exp_scores.iter().sum::<f32>();
        let attention_weights: Vec<f32> =
            exp_scores.iter().copied().map(|e| e / sum_exp).collect();
        // Attention-weighted sum of neighbour values, truncated to the
        // smaller of the value / output widths when they disagree.
        let output_dim = layer.weight_matrix.shape()[1];
        let mut aggregated = Array1::<f32>::zeros(output_dim);
        for (i, value) in neighbor_values.iter().enumerate() {
            let min_dim = aggregated.len().min(value.len());
            for j in 0..min_dim {
                aggregated[j] += value[j] * attention_weights[i];
            }
        }
        let transformed = feature.dot(&layer.weight_matrix) + &layer.bias;
        // Optional residual connection when dimensions allow it.
        let combined =
            if self.config.residual_connections && transformed.len() == aggregated.len() {
                transformed + &aggregated
            } else {
                transformed
            };
        let activated = combined.mapv(|x| x.max(0.0));
        new_features.insert(*node_idx, activated);
    }
}
/// GIN-style propagation: `(1 + eps) * h_v + aggregated neighbours`,
/// followed by the linear transform and ReLU.
fn apply_gin_layer(
    &self,
    layer: &GNNLayer,
    node_features: &HashMap<usize, Array1<f32>>,
    new_features: &mut HashMap<usize, Array1<f32>>,
) {
    // Fixed (untrained) epsilon from the GIN formulation.
    const EPSILON: f32 = 0.0;
    for (&idx, feat) in node_features {
        let neighborhood = self.aggregate_neighbors(idx, node_features);
        let mixed = feat * (1.0 + EPSILON) + neighborhood;
        let transformed = mixed.dot(&layer.weight_matrix) + &layer.bias;
        new_features.insert(idx, transformed.mapv(|v| v.max(0.0)));
    }
}
/// Standard layer normalization: standardize to zero mean and unit
/// variance (epsilon guards against a zero variance), then apply the
/// learned affine parameters gamma and beta.
fn apply_layer_norm(&self, input: &Array1<f32>, ln: &LayerNormalization) -> Array1<f32> {
    let mu = input.mean().unwrap_or(0.0);
    let var = input.mapv(|v| (v - mu).powi(2)).mean().unwrap_or(1.0);
    let denom = (var + ln.epsilon).sqrt();
    let standardized = input.mapv(|v| (v - mu) / denom);
    &standardized * &ln.gamma + &ln.beta
}
/// Runs the full message-passing stack over `initial_features`, applying
/// inverted dropout after every layer.
fn forward(
    &self,
    initial_features: HashMap<usize, Array1<f32>>,
) -> HashMap<usize, Array1<f32>> {
    let mut features = initial_features;
    let dropout_rate = self.config.dropout as f32;
    // Seed once for the whole pass. The previous version re-seeded with
    // 42 inside the loop, handing every layer the identical dropout mask
    // stream and defeating per-layer mask decorrelation.
    let mut rng = Random::seed(42);
    for layer in self.layers.iter() {
        let new_features = self.apply_layer(layer, &features);
        if dropout_rate <= 0.0 {
            // No dropout configured — skip the masking pass entirely.
            features = new_features;
            continue;
        }
        features = new_features
            .into_iter()
            .map(|(idx, feat)| {
                // Inverted dropout: keep with probability (1 - p) and
                // rescale so the expected activation is unchanged.
                let masked = feat.mapv(|x| {
                    if rng.random_range(0.0..1.0) > dropout_rate {
                        x / (1.0 - dropout_rate)
                    } else {
                        0.0
                    }
                });
                (idx, masked)
            })
            .collect();
    }
    features
}
}
#[async_trait]
impl EmbeddingModel for GNNEmbedding {
/// Returns the shared base configuration.
fn config(&self) -> &ModelConfig {
    &self.config.base_config
}
/// Unique identifier assigned at construction.
fn model_id(&self) -> &Uuid {
    &self.id
}
/// Static model-type name; the concrete GNN variant is reported via
/// `get_stats` instead.
fn model_type(&self) -> &'static str {
    "GNNEmbedding"
}
fn add_triple(&mut self, triple: Triple) -> Result<()> {
let subject = triple.subject.iri.clone();
let object = triple.object.iri.clone();
let predicate = triple.predicate.iri.clone();
if !self.entity_to_idx.contains_key(&subject) {
let idx = self.entity_to_idx.len();
self.entity_to_idx.insert(subject.clone(), idx);
self.idx_to_entity.insert(idx, subject);
}
if !self.entity_to_idx.contains_key(&object) {
let idx = self.entity_to_idx.len();
self.entity_to_idx.insert(object.clone(), idx);
self.idx_to_entity.insert(idx, object);
}
if !self.relation_to_idx.contains_key(&predicate) {
let idx = self.relation_to_idx.len();
self.relation_to_idx.insert(predicate.clone(), idx);
self.idx_to_relation.insert(idx, predicate);
}
self.triples.push(triple);
self.is_trained = false;
Ok(())
}
/// Trains by running repeated forward passes over the graph.
///
/// NOTE(review): there is no gradient-based weight update here — node
/// features are propagated through the fixed, randomly initialized
/// layers until the proxy loss drops below 0.001 or the epoch budget is
/// exhausted. Relation embeddings are drawn at random, not learned.
async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
    let start_time = std::time::Instant::now();
    let epochs = epochs.unwrap_or(self.config.base_config.max_epochs);
    self.build_adjacency_lists();
    self.initialize_layers()?;
    let mut rng = Random::seed(42);
    let dimensions = self.config.base_config.dimensions;
    // Random initial node features in [-0.05, 0.05).
    let mut initial_features = HashMap::new();
    for idx in self.entity_to_idx.values() {
        let embedding =
            Array1::from_shape_fn(dimensions, |_| rng.random_range(0.0..1.0) * 0.1 - 0.05);
        initial_features.insert(*idx, embedding);
    }
    let mut loss_history = Vec::new();
    // Guard against an empty graph: the mean below would divide by zero
    // nodes, poisoning the loss history with NaN values that can never
    // satisfy the convergence check.
    if !initial_features.is_empty() {
        for _epoch in 0..epochs {
            let output_features = self.forward(initial_features.clone());
            // Proxy loss: mean squared L2 norm of the node embeddings.
            let loss = output_features
                .values()
                .map(|f| f.mapv(|x| x * x).sum())
                .sum::<f32>()
                / output_features.len() as f32;
            loss_history.push(loss as f64);
            initial_features = output_features;
            if loss < 0.001 {
                break;
            }
        }
    }
    for (idx, embedding) in initial_features {
        if let Some(entity) = self.idx_to_entity.get(&idx) {
            self.entity_embeddings.insert(entity.clone(), embedding);
        }
    }
    // Relations get random embeddings; they are not refined by the
    // forward-only loop above.
    for relation in self.relation_to_idx.keys() {
        let embedding =
            Array1::from_shape_fn(dimensions, |_| rng.random_range(0.0..1.0) * 0.1 - 0.05);
        self.relation_embeddings.insert(relation.clone(), embedding);
    }
    self.is_trained = true;
    self.last_training_time = Some(Utc::now());
    Ok(TrainingStats {
        epochs_completed: loss_history.len(),
        final_loss: *loss_history.last().unwrap_or(&0.0),
        training_time_seconds: start_time.elapsed().as_secs_f64(),
        convergence_achieved: loss_history.last().unwrap_or(&1.0) < &0.001,
        loss_history,
    })
}
/// Looks up the trained embedding for `entity`.
///
/// # Errors
/// `ModelNotTrained` before `train` has run; `EntityNotFound` for an
/// unknown IRI.
fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
    if !self.is_trained {
        return Err(EmbeddingError::ModelNotTrained.into());
    }
    match self.entity_embeddings.get(entity) {
        Some(embedding) => Ok(Vector::new(embedding.to_vec())),
        None => Err(EmbeddingError::EntityNotFound {
            entity: entity.to_string(),
        }
        .into()),
    }
}
/// Looks up the embedding for `relation`.
///
/// # Errors
/// `ModelNotTrained` before `train` has run; `RelationNotFound` for an
/// unknown IRI.
fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
    if !self.is_trained {
        return Err(EmbeddingError::ModelNotTrained.into());
    }
    match self.relation_embeddings.get(relation) {
        Some(embedding) => Ok(Vector::new(embedding.to_vec())),
        None => Err(EmbeddingError::RelationNotFound {
            relation: relation.to_string(),
        }
        .into()),
    }
}
/// Scores (subject, predicate, object) as the sum of the element-wise
/// product `(e_s + e_p) ⊙ e_o` over the embedding dimensions.
fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
    if !self.is_trained {
        return Err(EmbeddingError::ModelNotTrained.into());
    }
    let Some(subj_emb) = self.entity_embeddings.get(subject) else {
        return Err(EmbeddingError::EntityNotFound {
            entity: subject.to_string(),
        }
        .into());
    };
    let Some(pred_emb) = self.relation_embeddings.get(predicate) else {
        return Err(EmbeddingError::RelationNotFound {
            relation: predicate.to_string(),
        }
        .into());
    };
    let Some(obj_emb) = self.entity_embeddings.get(object) else {
        return Err(EmbeddingError::EntityNotFound {
            entity: object.to_string(),
        }
        .into());
    };
    let interaction = (subj_emb + pred_emb) * obj_emb;
    Ok(interaction.sum() as f64)
}
/// Ranks all known entities as candidate objects for
/// (subject, predicate), returning the top-k (entity, score) pairs,
/// best first.
fn predict_objects(
    &self,
    subject: &str,
    predicate: &str,
    k: usize,
) -> Result<Vec<(String, f64)>> {
    if !self.is_trained {
        return Err(EmbeddingError::ModelNotTrained.into());
    }
    let mut scores: Vec<(String, f64)> = self
        .entity_to_idx
        .keys()
        .filter_map(|candidate| {
            self.score_triple(subject, predicate, candidate)
                .ok()
                .map(|score| (candidate.clone(), score))
        })
        .collect();
    // Descending by score; NaN-safe comparison falls back to Equal.
    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scores.truncate(k);
    Ok(scores)
}
/// Ranks all known entities as candidate subjects for
/// (predicate, object), returning the top-k (entity, score) pairs,
/// best first.
fn predict_subjects(
    &self,
    predicate: &str,
    object: &str,
    k: usize,
) -> Result<Vec<(String, f64)>> {
    if !self.is_trained {
        return Err(EmbeddingError::ModelNotTrained.into());
    }
    let mut scores: Vec<(String, f64)> = self
        .entity_to_idx
        .keys()
        .filter_map(|candidate| {
            self.score_triple(candidate, predicate, object)
                .ok()
                .map(|score| (candidate.clone(), score))
        })
        .collect();
    // Descending by score; NaN-safe comparison falls back to Equal.
    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scores.truncate(k);
    Ok(scores)
}
/// Ranks all known relations for (subject, object), returning the top-k
/// (relation, score) pairs, best first.
fn predict_relations(
    &self,
    subject: &str,
    object: &str,
    k: usize,
) -> Result<Vec<(String, f64)>> {
    if !self.is_trained {
        return Err(EmbeddingError::ModelNotTrained.into());
    }
    let mut scores: Vec<(String, f64)> = self
        .relation_to_idx
        .keys()
        .filter_map(|candidate| {
            self.score_triple(subject, candidate, object)
                .ok()
                .map(|score| (candidate.clone(), score))
        })
        .collect();
    // Descending by score; NaN-safe comparison falls back to Equal.
    scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scores.truncate(k);
    Ok(scores)
}
/// All known entity IRIs (iteration order unspecified).
fn get_entities(&self) -> Vec<String> {
    self.entity_to_idx.keys().cloned().collect()
}
/// All known relation IRIs (iteration order unspecified).
fn get_relations(&self) -> Vec<String> {
    self.relation_to_idx.keys().cloned().collect()
}
/// Snapshot of model size, training status, and timestamps.
fn get_stats(&self) -> ModelStats {
    ModelStats {
        num_entities: self.entity_to_idx.len(),
        num_relations: self.relation_to_idx.len(),
        num_triples: self.triples.len(),
        dimensions: self.config.base_config.dimensions,
        is_trained: self.is_trained,
        // Reports the concrete GNN variant, e.g. "GNNEmbedding-GCN".
        model_type: format!("GNNEmbedding-{:?}", self.config.gnn_type),
        creation_time: self.creation_time,
        last_training_time: self.last_training_time,
    }
}
/// NOTE(review): persistence is a stub — nothing is written to `_path`,
/// yet `Ok` is returned. Callers cannot distinguish this from a real
/// save.
fn save(&self, _path: &str) -> Result<()> {
    Ok(())
}
/// NOTE(review): loading is a stub — nothing is read from `_path`.
fn load(&mut self, _path: &str) -> Result<()> {
    Ok(())
}
/// Resets the model to its just-constructed state (the id and config are
/// kept).
fn clear(&mut self) {
    self.entity_embeddings.clear();
    self.relation_embeddings.clear();
    self.entity_to_idx.clear();
    self.relation_to_idx.clear();
    self.idx_to_entity.clear();
    self.idx_to_relation.clear();
    self.adjacency_list.clear();
    self.reverse_adjacency_list.clear();
    self.triples.clear();
    self.layers.clear();
    self.is_trained = false;
}
/// True once `train` has completed and no triples were added since.
fn is_trained(&self) -> bool {
    self.is_trained
}
/// Text encoding is not supported by this knowledge-graph model.
async fn encode(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
    Err(anyhow!(
        "Knowledge graph embedding model does not support text encoding"
    ))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;
    /// End-to-end smoke test: add two triples, train briefly, and check
    /// that entity embeddings and object predictions come back.
    #[tokio::test]
    async fn test_gnn_embedding_basic() {
        let config = GNNConfig {
            gnn_type: GNNType::GCN,
            num_layers: 2,
            hidden_dimensions: vec![64, 32],
            ..Default::default()
        };
        let mut model = GNNEmbedding::new(config);
        let triple1 = Triple::new(
            NamedNode::new("http://example.org/Alice").expect("should succeed"),
            NamedNode::new("http://example.org/knows").expect("should succeed"),
            NamedNode::new("http://example.org/Bob").expect("should succeed"),
        );
        let triple2 = Triple::new(
            NamedNode::new("http://example.org/Bob").expect("should succeed"),
            NamedNode::new("http://example.org/knows").expect("should succeed"),
            NamedNode::new("http://example.org/Charlie").expect("should succeed"),
        );
        model.add_triple(triple1).expect("should succeed");
        model.add_triple(triple2).expect("should succeed");
        let _stats = model.train(Some(10)).await.expect("should succeed");
        assert!(model.is_trained());
        let alice_emb = model
            .get_entity_embedding("http://example.org/Alice")
            .expect("should succeed");
        // NOTE(review): assumes ModelConfig::default() uses 100 dimensions.
        assert_eq!(alice_emb.dimensions, 100);
        let predictions = model
            .predict_objects("http://example.org/Alice", "http://example.org/knows", 5)
            .expect("should succeed");
        assert!(!predictions.is_empty());
    }
    /// Verifies every architecture with a dedicated propagation rule can
    /// train on a minimal one-triple graph without panicking.
    #[tokio::test]
    async fn test_gnn_types() {
        for gnn_type in [GNNType::GCN, GNNType::GraphSAGE, GNNType::GAT, GNNType::GIN] {
            let config = GNNConfig {
                gnn_type,
                // GAT is the only attention-based variant exercised here.
                num_heads: if gnn_type == GNNType::GAT {
                    Some(4)
                } else {
                    None
                },
                ..Default::default()
            };
            let mut model = GNNEmbedding::new(config);
            let triple = Triple::new(
                NamedNode::new("http://example.org/A").expect("should succeed"),
                NamedNode::new("http://example.org/rel").expect("should succeed"),
                NamedNode::new("http://example.org/B").expect("should succeed"),
            );
            model.add_triple(triple).expect("should succeed");
            let _stats = model.train(Some(5)).await.expect("should succeed");
            assert!(model.is_trained());
        }
    }
}