#![allow(
clippy::expect_used,
clippy::uninlined_format_args,
clippy::unwrap_used,
clippy::inefficient_to_string,
clippy::unused_async,
clippy::similar_names,
clippy::field_reassign_with_default,
clippy::unnecessary_literal_bound,
clippy::cast_precision_loss,
clippy::cast_sign_loss,
clippy::too_many_lines,
clippy::borrowed_box,
clippy::redundant_closure_for_method_calls,
unused_imports
)]
use chrono::Utc;
use do_memory_core::embeddings::{EmbeddingProvider, LocalConfig, LocalEmbeddingProvider};
use do_memory_core::{
ComplexityLevel, ExecutionResult, ExecutionStep, SelfLearningMemory, TaskContext, TaskOutcome,
TaskType,
};
#[tokio::main]
async fn main() -> anyhow::Result<()> {
println!("\nMemory-Core: End-to-End Embeddings Example");
println!("{}", "=".repeat(60));
println!();
println!("Step 1: Initializing memory system...");
let memory = SelfLearningMemory::new();
println!("Memory system initialized\n");
println!("Step 2: Initializing embedding provider...");
let provider = initialize_provider().await?;
println!("Using provider: {}", provider_name(&provider));
println!("Dimension: {}\n", provider.embedding_dimension());
println!("Step 3: Creating sample episodes with embeddings...");
let episodes = vec![
(
"Implement REST API authentication",
"web-api",
vec!["rest", "auth", "jwt"],
),
(
"Build OAuth2 login flow",
"web-api",
vec!["oauth", "authentication", "security"],
),
(
"Optimize database query performance",
"database",
vec!["sql", "performance", "indexing"],
),
(
"Create React component for user profile",
"frontend",
vec!["react", "ui", "components"],
),
(
"Implement WebSocket real-time notifications",
"web-api",
vec!["websocket", "realtime", "notifications"],
),
];
for (desc, domain, tags) in &episodes {
let embedding = provider.embed_text(desc).await?;
let embedding_len = embedding.len();
println!(
" Created episode: '{}' (embedding: {} dims)",
desc, embedding_len
);
let context = TaskContext {
domain: domain.to_string(),
language: Some("rust".to_string()),
framework: Some("axum".to_string()),
complexity: ComplexityLevel::Moderate,
tags: tags.iter().map(|s| s.to_string()).collect(),
};
let episode_id = memory
.start_episode(desc.to_string(), context, TaskType::CodeGeneration)
.await;
let step1 = ExecutionStep {
step_number: 1,
timestamp: Utc::now(),
tool: "analyze".to_string(),
action: "Analyzing requirements".to_string(),
parameters: serde_json::json!({}),
result: Some(ExecutionResult::Success {
output: "Requirements analyzed".to_string(),
}),
latency_ms: 10,
tokens_used: None,
metadata: std::collections::HashMap::new(),
};
memory.log_step(episode_id, step1).await;
let step2 = ExecutionStep {
step_number: 2,
timestamp: Utc::now(),
tool: "implement".to_string(),
action: "Implementing solution".to_string(),
parameters: serde_json::json!({}),
result: Some(ExecutionResult::Success {
output: "Implementation complete".to_string(),
}),
latency_ms: 10,
tokens_used: None,
metadata: std::collections::HashMap::new(),
};
memory.log_step(episode_id, step2).await;
let step3 = ExecutionStep {
step_number: 3,
timestamp: Utc::now(),
tool: "test".to_string(),
action: "Testing solution".to_string(),
parameters: serde_json::json!({}),
result: Some(ExecutionResult::Success {
output: "All tests passed".to_string(),
}),
latency_ms: 10,
tokens_used: None,
metadata: std::collections::HashMap::new(),
};
memory.log_step(episode_id, step3).await;
memory
.complete_episode(
episode_id,
TaskOutcome::Success {
verdict: "Task completed successfully".to_string(),
artifacts: vec![],
},
)
.await?;
}
println!("Created {} episodes\n", episodes.len());
println!("Step 4: Performing semantic similarity searches...");
println!();
let queries = vec![
("user authentication", "web-api", vec!["auth"]),
("database optimization", "database", vec!["performance"]),
("frontend UI component", "frontend", vec!["ui"]),
];
for (query, domain, tags) in &queries {
println!("Query: \"{}\"", query);
println!("{}", "-".repeat(60));
let query_embedding = provider.embed_text(query).await?;
let query_embedding_len = query_embedding.len();
println!(" Query embedding: {} dimensions", query_embedding_len);
let context = TaskContext {
domain: domain.to_string(),
language: Some("rust".to_string()),
framework: Some("axum".to_string()),
complexity: ComplexityLevel::Moderate,
tags: tags.iter().map(|s| s.to_string()).collect(),
};
let relevant = memory
.retrieve_relevant_context(query.to_string(), context, 3)
.await;
println!(" Found {} relevant episodes:", relevant.len());
for (i, episode) in relevant.iter().enumerate() {
println!(" {}. {}", i + 1, episode.task_description);
if let Some(reward) = &episode.reward {
println!(" Reward: {:.2}", reward.total);
}
}
println!();
}
println!("Step 5: Direct similarity calculations...");
println!();
let text_pairs = vec![
("REST API", "web service API"),
("OAuth authentication", "user login"),
("database indexing", "React components"),
];
for (text1, text2) in text_pairs {
let similarity = provider.similarity(text1, text2).await?;
println!(" Similarity('{}', '{}') = {:.3}", text1, text2, similarity);
}
println!();
println!("Step 6: Batch embedding generation...");
let batch_texts = vec![
"Implement user authentication".to_string(),
"Create database migration".to_string(),
"Build API endpoint".to_string(),
"Write unit tests".to_string(),
];
let batch_results = provider.embed_batch(&batch_texts).await?;
println!(" Generated {} embeddings in batch", batch_results.len());
for (i, embedding) in batch_results.iter().enumerate() {
println!(
" {}. '{}' → {} dims",
i + 1,
batch_texts[i],
embedding.len()
);
}
println!();
println!("Example Complete!");
println!();
println!("Key Takeaways:");
println!(" - Embeddings enable semantic (meaning-based) search");
println!(" - Multiple providers supported (local, OpenAI, etc.)");
println!(" - Batch processing available for efficiency");
println!(" - Seamlessly integrates with memory system");
println!();
println!("Next Steps:");
println!(" 1. Try with different providers (see EMBEDDING_PROVIDERS.md)");
println!(" 2. Experiment with different similarity thresholds");
println!(" 3. Integrate with your own application");
println!(" 4. See memory-cli for command-line usage");
Ok(())
}
/// Selects the embedding provider for the example, announcing the choice on stdout.
///
/// Candidates are tried in this order, and the first one that constructs successfully
/// is returned:
/// 1. the local provider, only when built with the `local-embeddings` feature;
/// 2. the OpenAI provider, only when built with the `openai` feature and the
///    `OPENAI_API_KEY` environment variable can be read;
/// 3. `MockEmbeddingProvider`, as the unconditional fallback.
///
/// A candidate that fails to construct is reported on stdout and skipped, so this
/// function never returns `Err` itself.
async fn initialize_provider() -> anyhow::Result<Box<dyn EmbeddingProvider>> {
    #[cfg(feature = "local-embeddings")]
    {
        let attempt = LocalEmbeddingProvider::new(LocalConfig::default()).await;
        match attempt {
            Err(error) => {
                println!(" Local provider failed: {}", error);
            }
            Ok(local) => {
                println!(" Using Local Embedding Provider (CPU-based)");
                return Ok(Box::new(local));
            }
        }
    }

    #[cfg(feature = "openai")]
    {
        use do_memory_core::embeddings::{OpenAIConfig, OpenAIEmbeddingProvider};

        // A missing or unreadable key silently skips OpenAI and falls through.
        if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
            let attempt =
                OpenAIEmbeddingProvider::new(api_key, OpenAIConfig::text_embedding_3_small());
            match attempt {
                Err(error) => {
                    println!(" OpenAI provider failed: {}", error);
                }
                Ok(remote) => {
                    println!(" Using OpenAI Embedding Provider");
                    return Ok(Box::new(remote));
                }
            }
        }
    }

    // Reached when no feature-gated provider was compiled in or none could be built.
    println!(" Using Mock Provider (random embeddings - not semantically meaningful)");
    println!(" For production use, enable 'openai' or 'local-embeddings' feature");
    Ok(Box::new(MockEmbeddingProvider))
}
/// Returns a display label for `provider`, chosen purely from its embedding dimension.
///
/// 384, 1536 and 768 map to the local, OpenAI and mock labels respectively; every
/// other dimension is reported as "Custom Provider". The label is a guess based on
/// vector size only, so a different provider with one of these dimensions gets the
/// same label.
fn provider_name(provider: &Box<dyn EmbeddingProvider>) -> &str {
    // (embedding dimension, label) pairs for the providers this example knows about.
    const LABELS_BY_DIMENSION: [(usize, &str); 3] = [
        (384, "Local (sentence-transformers)"),
        (1536, "OpenAI (text-embedding-3-small)"),
        (768, "Mock Provider"),
    ];

    let dimension = provider.embedding_dimension();
    LABELS_BY_DIMENSION
        .iter()
        .find(|&&(known, _)| known == dimension)
        .map_or("Custom Provider", |&(_, label)| label)
}
/// Length of every vector produced by `MockEmbeddingProvider`.
const MOCK_EMBEDDING_DIMENSION: usize = 768;

/// Fallback provider that derives a pseudo-random vector from a hash of the input text.
///
/// The vectors carry no semantic meaning; they only let the example run end to end
/// when no real provider is available.
struct MockEmbeddingProvider;

#[async_trait::async_trait]
impl EmbeddingProvider for MockEmbeddingProvider {
    /// Returns a unit-length vector of `MOCK_EMBEDDING_DIMENSION` components derived
    /// from `text`. This implementation never returns `Err`.
    async fn embed_text(&self, text: &str) -> anyhow::Result<Vec<f32>> {
        Ok(mock_embedding(text, MOCK_EMBEDDING_DIMENSION))
    }

    /// Returns the fixed length of the vectors produced by `embed_text`.
    fn embedding_dimension(&self) -> usize {
        MOCK_EMBEDDING_DIMENSION
    }

    /// Returns the fixed identifier `"mock-provider"`.
    fn model_name(&self) -> &str {
        "mock-provider"
    }
}

/// Expands `text` into a normalized pseudo-random vector with `dimension` components.
///
/// The text is hashed with `DefaultHasher::new()` and the hash seeds a linear
/// congruential generator, so equal texts always yield equal vectors within one build.
/// The hash algorithm is a standard-library implementation detail, so the exact values
/// should not be relied on across Rust releases.
fn mock_embedding(text: &str, dimension: usize) -> Vec<f32> {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    text.hash(&mut hasher);
    let mut state = hasher.finish();

    let mut embedding: Vec<f32> = (0..dimension)
        .map(|_| {
            state = state.wrapping_mul(1_103_515_245).wrapping_add(12345);
            // The state is 64 bits wide, so keep only bits 16..32. Without the mask the
            // shifted value reaches 2^48, making every sample a huge non-negative
            // number and every normalized vector point into the same orthant.
            let sample = (state >> 16) & 0xFFFF;
            // 0..=65535 scaled by 1/32768 and shifted by -1 lands in [-1, 1).
            sample as f32 / 32768.0 - 1.0
        })
        .collect();

    // Scale to unit length; skipped for a zero-magnitude (or empty) vector to avoid
    // dividing by zero.
    let magnitude = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    if magnitude > 0.0 {
        for component in &mut embedding {
            *component /= magnitude;
        }
    }
    embedding
}