//! Batch analysis for workflow nodes.
//!
//! Groups nodes that share a backend (LLM provider, vector database, or tool
//! server) into batches and estimates the resulting speedup.

use oxify_model::{Node, NodeId, NodeKind};
use std::collections::HashMap;

/// Classification of a node for batching purposes.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum BatchGroup {
    /// LLM calls routed to the same provider (e.g. "openai").
    LlmProvider(String),
    /// Queries against the same vector database backend.
    VectorDatabase(String),
    /// Tool invocations served by the same tool server.
    ToolServer(String),
    /// Nodes with no external backend that can simply run concurrently.
    Parallel,
    /// Nodes that always execute individually.
    None,
}

/// The result of batch analysis: nodes grouped into batches plus the nodes
/// that must run individually.
#[derive(Debug, Clone)]
pub struct BatchPlan {
    pub batches: Vec<Batch>,
    pub individual_nodes: Vec<NodeId>,
}

/// A group of nodes that can be executed together.
#[derive(Debug, Clone)]
pub struct Batch {
    pub group: BatchGroup,
    pub nodes: Vec<NodeId>,
    pub speedup_factor: f32,
}

impl Batch {
    /// Create a batch and compute its estimated speedup factor.
    pub fn new(group: BatchGroup, nodes: Vec<NodeId>) -> Self {
        let speedup_factor = calculate_speedup_factor(&group, nodes.len());
        Self {
            group,
            nodes,
            speedup_factor,
        }
    }

    /// Number of nodes in this batch.
    pub fn size(&self) -> usize {
        self.nodes.len()
    }
}

/// Estimate the speedup from executing `batch_size` nodes of the given group
/// together. Backend groups gain a fixed bonus per batched node with a cap,
/// so batching has diminishing returns; fully parallel groups scale linearly
/// up to a concurrency limit.
fn calculate_speedup_factor(group: &BatchGroup, batch_size: usize) -> f32 {
    if batch_size <= 1 {
        return 1.0;
    }
    match group {
        // +0.15 per batched node, total bonus capped at 0.5 (1.5x speedup).
        BatchGroup::LlmProvider(_) => 1.0 + (batch_size as f32 * 0.15).min(0.5),
        // +0.25 per batched node, capped at 0.75 (1.75x speedup).
        BatchGroup::VectorDatabase(_) => 1.0 + (batch_size as f32 * 0.25).min(0.75),
        // +0.2 per batched node, capped at 0.6 (1.6x speedup).
        BatchGroup::ToolServer(_) => 1.0 + (batch_size as f32 * 0.2).min(0.6),
        // Linear speedup up to 8 concurrent nodes.
        BatchGroup::Parallel => (batch_size as f32).min(8.0),
        BatchGroup::None => 1.0,
    }
}

/// Analyzes a set of nodes and groups them into a [`BatchPlan`].
pub struct BatchAnalyzer {
    /// Smallest group that is worth batching; smaller groups run individually.
    pub min_batch_size: usize,
    /// Largest allowed batch; bigger groups are split into chunks.
    pub max_batch_size: usize,
}

impl Default for BatchAnalyzer {
    fn default() -> Self {
        Self::new()
    }
}

impl BatchAnalyzer {
    pub fn new() -> Self {
        Self {
            min_batch_size: 2,
            max_batch_size: 10,
        }
    }

    pub fn with_limits(min_batch_size: usize, max_batch_size: usize) -> Self {
        Self {
            min_batch_size,
            max_batch_size,
        }
    }

    /// Group nodes by [`BatchGroup`], then split each group into batches that
    /// respect `min_batch_size` and `max_batch_size`.
    pub fn analyze(&self, nodes: &[&Node]) -> BatchPlan {
        let mut groups: HashMap<BatchGroup, Vec<NodeId>> = HashMap::new();
        for node in nodes {
            let group = self.classify_node(node);
            groups.entry(group).or_default().push(node.id);
        }

        let mut batches = Vec::new();
        let mut individual_nodes = Vec::new();
        for (group, node_ids) in groups {
            if matches!(group, BatchGroup::None) {
                individual_nodes.extend(node_ids);
            } else if node_ids.len() >= self.min_batch_size {
                // Split oversized groups into chunks of at most max_batch_size.
                for chunk in node_ids.chunks(self.max_batch_size) {
                    batches.push(Batch::new(group.clone(), chunk.to_vec()));
                }
            } else {
                // Too few nodes of this kind to be worth batching.
                individual_nodes.extend(node_ids);
            }
        }

        BatchPlan {
            batches,
            individual_nodes,
        }
    }

    /// Map a node to its batching group based on its kind and backend.
    fn classify_node(&self, node: &Node) -> BatchGroup {
        match &node.kind {
            NodeKind::LLM(config) => BatchGroup::LlmProvider(config.provider.clone()),
            NodeKind::Retriever(config) => BatchGroup::VectorDatabase(config.db_type.clone()),
            NodeKind::Tool(config) => BatchGroup::ToolServer(config.server_id.clone()),
            // Pure computation and branching nodes can run concurrently.
            NodeKind::Code(_) | NodeKind::IfElse(_) | NodeKind::Switch(_) => BatchGroup::Parallel,
            // Everything else executes individually.
            NodeKind::Start
            | NodeKind::End
            | NodeKind::Loop(_)
            | NodeKind::TryCatch(_)
            | NodeKind::SubWorkflow(_)
            | NodeKind::Parallel(_)
            | NodeKind::Approval(_)
            | NodeKind::Form(_)
            | NodeKind::Vision(_) => BatchGroup::None,
        }
    }

    /// Estimate the fraction of total execution time saved by the plan,
    /// assuming roughly equal per-node cost: a batch with speedup `s` saves a
    /// `1 - 1/s` fraction of the time of each node it contains.
    pub fn estimate_time_savings(&self, plan: &BatchPlan) -> f32 {
        let mut total_time_saved = 0.0;
        for batch in &plan.batches {
            let time_saved_ratio = 1.0 - (1.0 / batch.speedup_factor);
            total_time_saved += time_saved_ratio * batch.size() as f32;
        }
        let total_nodes =
            plan.batches.iter().map(|b| b.size()).sum::<usize>() + plan.individual_nodes.len();
        if total_nodes > 0 {
            total_time_saved / total_nodes as f32
        } else {
            0.0
        }
    }
}

/// Summary statistics derived from a [`BatchPlan`].
#[derive(Debug, Clone, Default)]
pub struct BatchStats {
    pub total_nodes: usize,
    pub batched_nodes: usize,
    pub batch_count: usize,
    /// Estimated fraction of execution time saved, per [`BatchAnalyzer::estimate_time_savings`].
    pub estimated_time_savings: f32,
    pub average_batch_size: f32,
}

impl BatchStats {
    pub fn from_plan(plan: &BatchPlan, analyzer: &BatchAnalyzer) -> Self {
        let batched_nodes: usize = plan.batches.iter().map(|b| b.size()).sum();
        let batch_count = plan.batches.len();
        let total_nodes = batched_nodes + plan.individual_nodes.len();
        let average_batch_size = if batch_count > 0 {
            batched_nodes as f32 / batch_count as f32
        } else {
            0.0
        };
        let estimated_time_savings = analyzer.estimate_time_savings(plan);
        Self {
            total_nodes,
            batched_nodes,
            batch_count,
            estimated_time_savings,
            average_batch_size,
        }
    }

    /// Fraction of nodes that ended up in a batch, in `[0.0, 1.0]`.
    pub fn efficiency(&self) -> f32 {
        if self.total_nodes > 0 {
            self.batched_nodes as f32 / self.total_nodes as f32
        } else {
            0.0
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use oxify_model::{LlmConfig, VectorConfig};

    /// Build an "openai" LLM node for tests; only the name and prompt vary.
    fn make_llm_node(name: &str, prompt: &str) -> Node {
        Node::new(
            name.to_string(),
            NodeKind::LLM(LlmConfig {
                provider: "openai".to_string(),
                model: "gpt-4".to_string(),
                system_prompt: None,
                prompt_template: prompt.to_string(),
                temperature: Some(0.7),
                max_tokens: Some(1000),
                tools: Vec::new(),
                images: Vec::new(),
                extra_params: serde_json::Value::Null,
            }),
        )
    }

    #[test]
    fn test_batch_classification() {
        let analyzer = BatchAnalyzer::new();
        let llm_node = make_llm_node("LLM", "test");
        let group = analyzer.classify_node(&llm_node);
        assert_eq!(group, BatchGroup::LlmProvider("openai".to_string()));
    }

    #[test]
    fn test_batch_plan_creation() {
        let analyzer = BatchAnalyzer::new();
        let nodes = [
            make_llm_node("LLM 1", "test1"),
            make_llm_node("LLM 2", "test2"),
            Node::new(
                "Vector".to_string(),
                NodeKind::Retriever(VectorConfig {
                    db_type: "qdrant".to_string(),
                    collection: "docs".to_string(),
                    query: "test".to_string(),
                    top_k: 5,
                    score_threshold: Some(0.7),
                }),
            ),
        ];
        let node_refs: Vec<&Node> = nodes.iter().collect();
        let plan = analyzer.analyze(&node_refs);
        // The two LLM nodes share a provider and batch together; the lone
        // retriever is below min_batch_size and stays individual.
        assert_eq!(plan.batches.len(), 1);
        assert_eq!(plan.individual_nodes.len(), 1);
        assert_eq!(plan.batches[0].size(), 2);
    }

    #[test]
    fn test_speedup_calculation() {
        let llm_group = BatchGroup::LlmProvider("openai".to_string());
        let speedup_2 = calculate_speedup_factor(&llm_group, 2);
        let speedup_10 = calculate_speedup_factor(&llm_group, 10);
        assert!(speedup_2 > 1.0);
        assert!(speedup_10 > speedup_2);
        // LLM batching gains are capped at a 1.5x speedup.
        assert!(speedup_10 <= 1.5);
    }

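    // Companion check for the parallel group: per calculate_speedup_factor, the
    // speedup is linear in the batch size and capped at 8 concurrent nodes.
    #[test]
    fn test_parallel_speedup_is_linear_and_capped() {
        assert_eq!(calculate_speedup_factor(&BatchGroup::Parallel, 4), 4.0);
        assert_eq!(calculate_speedup_factor(&BatchGroup::Parallel, 20), 8.0);
        assert_eq!(calculate_speedup_factor(&BatchGroup::Parallel, 1), 1.0);
    }
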
    #[test]
    fn test_batch_stats() {
        let analyzer = BatchAnalyzer::new();
        let nodes = [
            make_llm_node("LLM 1", "test1"),
            make_llm_node("LLM 2", "test2"),
        ];
        let node_refs: Vec<&Node> = nodes.iter().collect();
        let plan = analyzer.analyze(&node_refs);
        let stats = BatchStats::from_plan(&plan, &analyzer);
        assert_eq!(stats.total_nodes, 2);
        assert_eq!(stats.batched_nodes, 2);
        assert_eq!(stats.batch_count, 1);
        // Every node was batched, so efficiency is 1.0 and some time is saved.
        assert!(stats.efficiency() > 0.9);
        assert!(stats.estimated_time_savings > 0.0);
    }

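    // A sketch of the non-batchable path: control-flow nodes such as
    // NodeKind::Start are classified as BatchGroup::None and always run
    // individually rather than joining a batch.
    #[test]
    fn test_control_flow_nodes_are_never_batched() {
        let analyzer = BatchAnalyzer::new();
        let start = Node::new("Start".to_string(), NodeKind::Start);
        assert_eq!(analyzer.classify_node(&start), BatchGroup::None);

        let plan = analyzer.analyze(&[&start]);
        assert!(plan.batches.is_empty());
        assert_eq!(plan.individual_nodes.len(), 1);
    }
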
    #[test]
    fn test_max_batch_size_splitting() {
        // A max_batch_size of 3 forces five nodes into chunks of 3 and 2.
        let analyzer = BatchAnalyzer::with_limits(2, 3);
        let mut nodes = Vec::new();
        for i in 0..5 {
            nodes.push(make_llm_node(&format!("LLM {}", i), &format!("test{}", i)));
        }
        let node_refs: Vec<&Node> = nodes.iter().collect();
        let plan = analyzer.analyze(&node_refs);
        assert_eq!(plan.batches.len(), 2);
        assert_eq!(plan.batches[0].size(), 3);
        assert_eq!(plan.batches[1].size(), 2);
    }

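    // Boundary check in the other direction: a single LLM node is below the
    // default min_batch_size of 2, so it should stay individual rather than
    // form a one-node batch.
    #[test]
    fn test_below_min_batch_size_stays_individual() {
        let analyzer = BatchAnalyzer::new();
        let node = make_llm_node("LLM solo", "test");
        let plan = analyzer.analyze(&[&node]);
        assert!(plan.batches.is_empty());
        assert_eq!(plan.individual_nodes.len(), 1);
    }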
}