use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use uuid::Uuid;
use arrow_array::{
Array, FixedSizeListArray, Float32Array, Int64Array, ListArray, RecordBatch, StringArray,
UInt32Array,
};
use arrow_schema::{DataType, Field, Schema};
use crate::store::{CodeBlock, CommitBlock, DocumentBlock, TextBlock};
/// Converts the store's in-memory block types (`CodeBlock`, `TextBlock`,
/// `DocumentBlock`, `CommitBlock`) to and from Arrow `RecordBatch`es whose
/// `embedding` column is a fixed-width `FixedSizeList<Float32>`.
///
/// NOTE(review): `Serialize`/`Deserialize` are derived — presumably so the
/// converter's configuration can be persisted alongside the index it was
/// built with; confirm against the code that stores it.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct BatchConverter {
// Required length of every embedding vector; every batch produced by this
// converter uses a FixedSizeList of exactly this width.
vector_dim: usize,
}
impl BatchConverter {
pub fn new(vector_dim: usize) -> Self {
Self { vector_dim }
}
pub fn code_block_to_batch(
&self,
blocks: &[CodeBlock],
embeddings: &[Vec<f32>],
) -> Result<RecordBatch> {
if blocks.len() != embeddings.len() {
return Err(anyhow::anyhow!(
"Number of blocks and embeddings must match"
));
}
if blocks.is_empty() {
return Err(anyhow::anyhow!("Empty blocks array"));
}
for (i, embedding) in embeddings.iter().enumerate() {
if embedding.len() != self.vector_dim {
return Err(anyhow::anyhow!(
"Embedding at index {} has dimension {} but expected {}",
i,
embedding.len(),
self.vector_dim
));
}
}
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("path", DataType::Utf8, false),
Field::new("language", DataType::Utf8, false),
Field::new("content", DataType::Utf8, false),
Field::new("symbols", DataType::Utf8, true), Field::new("start_line", DataType::UInt32, false),
Field::new("end_line", DataType::UInt32, false),
Field::new("hash", DataType::Utf8, false),
Field::new(
"embedding",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
self.vector_dim as i32,
),
true,
),
]));
let ids: Vec<String> = blocks.iter().map(|_| Uuid::new_v4().to_string()).collect();
let paths: Vec<&str> = blocks.iter().map(|b| b.path.as_str()).collect();
let languages: Vec<&str> = blocks.iter().map(|b| b.language.as_str()).collect();
let contents: Vec<&str> = blocks.iter().map(|b| b.content.as_str()).collect();
let symbols: Vec<String> = blocks
.iter()
.map(|b| serde_json::to_string(&b.symbols).unwrap_or_default())
.collect();
let start_lines: Vec<u32> = blocks.iter().map(|b| b.start_line as u32).collect();
let end_lines: Vec<u32> = blocks.iter().map(|b| b.end_line as u32).collect();
let hashes: Vec<&str> = blocks.iter().map(|b| b.hash.as_str()).collect();
let mut flattened_embeddings = Vec::with_capacity(blocks.len() * self.vector_dim);
for embedding in embeddings {
flattened_embeddings.extend_from_slice(embedding);
}
let values = Float32Array::from(flattened_embeddings);
let embedding_array = FixedSizeListArray::new(
Arc::new(Field::new("item", DataType::Float32, true)),
self.vector_dim as i32,
Arc::new(values),
None, );
let expected_len = blocks.len();
assert_eq!(ids.len(), expected_len, "ids array length mismatch");
assert_eq!(paths.len(), expected_len, "paths array length mismatch");
assert_eq!(
languages.len(),
expected_len,
"languages array length mismatch"
);
assert_eq!(
contents.len(),
expected_len,
"contents array length mismatch"
);
assert_eq!(symbols.len(), expected_len, "symbols array length mismatch");
assert_eq!(
start_lines.len(),
expected_len,
"start_lines array length mismatch"
);
assert_eq!(
end_lines.len(),
expected_len,
"end_lines array length mismatch"
);
assert_eq!(hashes.len(), expected_len, "hashes array length mismatch");
assert_eq!(
embedding_array.len(),
expected_len,
"embedding array length mismatch"
);
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(StringArray::from(ids)),
Arc::new(StringArray::from(paths)),
Arc::new(StringArray::from(languages)),
Arc::new(StringArray::from(contents)),
Arc::new(StringArray::from(symbols)),
Arc::new(UInt32Array::from(start_lines)),
Arc::new(UInt32Array::from(end_lines)),
Arc::new(StringArray::from(hashes)),
Arc::new(embedding_array),
],
)?;
Ok(batch)
}
/// Converts text blocks plus their embeddings into an Arrow `RecordBatch`.
///
/// Each row receives a freshly generated UUID in the `id` column.
///
/// # Errors
/// Fails when `blocks` is empty, when `blocks` and `embeddings` differ in
/// length, or when any embedding's length differs from `self.vector_dim`.
pub fn text_block_to_batch(
    &self,
    blocks: &[TextBlock],
    embeddings: &[Vec<f32>],
) -> Result<RecordBatch> {
    if blocks.len() != embeddings.len() {
        return Err(anyhow::anyhow!(
            "Number of blocks and embeddings must match"
        ));
    }
    if blocks.is_empty() {
        return Err(anyhow::anyhow!("Empty blocks array"));
    }
    for (i, embedding) in embeddings.iter().enumerate() {
        if embedding.len() != self.vector_dim {
            return Err(anyhow::anyhow!(
                "Embedding at index {} has dimension {} but expected {}",
                i,
                embedding.len(),
                self.vector_dim
            ));
        }
    }
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("path", DataType::Utf8, false),
        Field::new("language", DataType::Utf8, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("start_line", DataType::UInt32, false),
        Field::new("end_line", DataType::UInt32, false),
        Field::new("hash", DataType::Utf8, false),
        Field::new(
            "embedding",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                self.vector_dim as i32,
            ),
            true,
        ),
    ]));
    let ids: Vec<String> = blocks.iter().map(|_| Uuid::new_v4().to_string()).collect();
    let paths: Vec<&str> = blocks.iter().map(|b| b.path.as_str()).collect();
    let languages: Vec<&str> = blocks.iter().map(|b| b.language.as_str()).collect();
    let contents: Vec<&str> = blocks.iter().map(|b| b.content.as_str()).collect();
    let start_lines: Vec<u32> = blocks.iter().map(|b| b.start_line as u32).collect();
    let end_lines: Vec<u32> = blocks.iter().map(|b| b.end_line as u32).collect();
    let hashes: Vec<&str> = blocks.iter().map(|b| b.hash.as_str()).collect();
    // Flatten per-block embeddings into one contiguous Float32 buffer;
    // dimensions were validated above.
    let mut flattened_embeddings = Vec::with_capacity(blocks.len() * self.vector_dim);
    for embedding in embeddings {
        flattened_embeddings.extend_from_slice(embedding);
    }
    let embedding_array = FixedSizeListArray::new(
        Arc::new(Field::new("item", DataType::Float32, true)),
        self.vector_dim as i32,
        Arc::new(Float32Array::from(flattened_embeddings)),
        None,
    );
    // The per-column length assertions that used to live here were trivially
    // true (every column is collected from `blocks.iter()`), and
    // `RecordBatch::try_new` re-validates column lengths against the schema.
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(StringArray::from(ids)),
            Arc::new(StringArray::from(paths)),
            Arc::new(StringArray::from(languages)),
            Arc::new(StringArray::from(contents)),
            Arc::new(UInt32Array::from(start_lines)),
            Arc::new(UInt32Array::from(end_lines)),
            Arc::new(StringArray::from(hashes)),
            Arc::new(embedding_array),
        ],
    )?;
    Ok(batch)
}
/// Converts document blocks plus their embeddings into an Arrow
/// `RecordBatch`.
///
/// Each row receives a freshly generated UUID in the `id` column; the
/// per-block `context` strings are stored as a nested `List<Utf8>` column.
///
/// # Errors
/// Fails when `blocks` is empty, when `blocks` and `embeddings` differ in
/// length, or when any embedding's length differs from `self.vector_dim`.
pub fn document_block_to_batch(
    &self,
    blocks: &[DocumentBlock],
    embeddings: &[Vec<f32>],
) -> Result<RecordBatch> {
    if blocks.len() != embeddings.len() {
        return Err(anyhow::anyhow!(
            "Number of blocks and embeddings must match"
        ));
    }
    if blocks.is_empty() {
        return Err(anyhow::anyhow!("Empty blocks array"));
    }
    for (i, embedding) in embeddings.iter().enumerate() {
        if embedding.len() != self.vector_dim {
            return Err(anyhow::anyhow!(
                "Embedding at index {} has dimension {} but expected {}",
                i,
                embedding.len(),
                self.vector_dim
            ));
        }
    }
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("path", DataType::Utf8, false),
        Field::new("title", DataType::Utf8, false),
        Field::new("content", DataType::Utf8, false),
        Field::new(
            "context",
            DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
            true,
        ),
        Field::new("level", DataType::UInt32, false),
        Field::new("start_line", DataType::UInt32, false),
        Field::new("end_line", DataType::UInt32, false),
        Field::new("hash", DataType::Utf8, false),
        Field::new(
            "embedding",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                self.vector_dim as i32,
            ),
            true,
        ),
    ]));
    let ids: Vec<String> = blocks.iter().map(|_| Uuid::new_v4().to_string()).collect();
    let paths: Vec<&str> = blocks.iter().map(|b| b.path.as_str()).collect();
    let titles: Vec<&str> = blocks.iter().map(|b| b.title.as_str()).collect();
    let contents: Vec<&str> = blocks.iter().map(|b| b.content.as_str()).collect();
    // Build the nested `context` column: `context_values` is the
    // concatenation of every block's context strings, and the offsets mark
    // where each block's slice begins and ends.
    let mut context_values = Vec::new();
    let mut context_offsets = vec![0i32];
    for block in blocks {
        for context_item in &block.context {
            context_values.push(context_item.as_str());
        }
        context_offsets.push(context_values.len() as i32);
    }
    let context_array = ListArray::new(
        Arc::new(Field::new("item", DataType::Utf8, true)),
        arrow::buffer::OffsetBuffer::new(context_offsets.into()),
        Arc::new(StringArray::from(context_values)),
        None,
    );
    let levels: Vec<u32> = blocks.iter().map(|b| b.level as u32).collect();
    let start_lines: Vec<u32> = blocks.iter().map(|b| b.start_line as u32).collect();
    let end_lines: Vec<u32> = blocks.iter().map(|b| b.end_line as u32).collect();
    let hashes: Vec<&str> = blocks.iter().map(|b| b.hash.as_str()).collect();
    // Flatten per-block embeddings into one contiguous Float32 buffer;
    // dimensions were validated above.
    let mut flattened_embeddings = Vec::with_capacity(blocks.len() * self.vector_dim);
    for embedding in embeddings {
        flattened_embeddings.extend_from_slice(embedding);
    }
    let embedding_array = FixedSizeListArray::new(
        Arc::new(Field::new("item", DataType::Float32, true)),
        self.vector_dim as i32,
        Arc::new(Float32Array::from(flattened_embeddings)),
        None,
    );
    // The assertion block that used to live here (including an exact
    // duplicate of the embedding-length assertion) was trivially true:
    // every column is collected from `blocks.iter()`, and
    // `RecordBatch::try_new` re-validates column lengths against the schema.
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(StringArray::from(ids)),
            Arc::new(StringArray::from(paths)),
            Arc::new(StringArray::from(titles)),
            Arc::new(StringArray::from(contents)),
            Arc::new(context_array),
            Arc::new(UInt32Array::from(levels)),
            Arc::new(UInt32Array::from(start_lines)),
            Arc::new(UInt32Array::from(end_lines)),
            Arc::new(StringArray::from(hashes)),
            Arc::new(embedding_array),
        ],
    )?;
    Ok(batch)
}
pub fn batch_to_code_blocks(
&self,
batch: &RecordBatch,
_embeddings: Option<&[Vec<f32>]>,
) -> Result<Vec<CodeBlock>> {
let mut code_blocks = Vec::new();
let _id_array = batch
.column_by_name("id")
.ok_or_else(|| anyhow::anyhow!("id column not found"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("id column is not a StringArray"))?;
let path_array = batch
.column_by_name("path")
.ok_or_else(|| anyhow::anyhow!("path column not found"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("path column is not a StringArray"))?;
let language_array = batch
.column_by_name("language")
.ok_or_else(|| anyhow::anyhow!("language column not found"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("language column is not a StringArray"))?;
let content_array = batch
.column_by_name("content")
.ok_or_else(|| anyhow::anyhow!("content column not found"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("content column is not a StringArray"))?;
let symbols_array = batch
.column_by_name("symbols")
.ok_or_else(|| anyhow::anyhow!("symbols column not found"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("symbols column is not a StringArray"))?;
let start_line_array = batch
.column_by_name("start_line")
.ok_or_else(|| anyhow::anyhow!("start_line column not found"))?
.as_any()
.downcast_ref::<UInt32Array>()
.ok_or_else(|| anyhow::anyhow!("start_line column is not a UInt32Array"))?;
let end_line_array = batch
.column_by_name("end_line")
.ok_or_else(|| anyhow::anyhow!("end_line column not found"))?
.as_any()
.downcast_ref::<UInt32Array>()
.ok_or_else(|| anyhow::anyhow!("end_line column is not a UInt32Array"))?;
let hash_array = batch
.column_by_name("hash")
.ok_or_else(|| anyhow::anyhow!("hash column not found"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("hash column is not a StringArray"))?;
let distance_array = batch
.column_by_name("_distance")
.and_then(|col| col.as_any().downcast_ref::<Float32Array>())
.map(|arr| (0..arr.len()).map(|i| arr.value(i)).collect::<Vec<f32>>())
.unwrap_or_default();
for i in 0..batch.num_rows() {
let symbols_json = symbols_array.value(i);
let symbols: Vec<String> = if symbols_json.is_empty() {
Vec::new()
} else {
serde_json::from_str(symbols_json).unwrap_or_default()
};
let code_block = CodeBlock {
path: path_array.value(i).to_string(),
language: language_array.value(i).to_string(),
content: content_array.value(i).to_string(),
symbols,
start_line: (start_line_array.value(i) as usize) + 1, end_line: (end_line_array.value(i) as usize) + 1, hash: hash_array.value(i).to_string(),
distance: distance_array.get(i).copied(),
};
code_blocks.push(code_block);
}
Ok(code_blocks)
}
/// Reconstructs `TextBlock`s from a `RecordBatch` produced by
/// `text_block_to_batch` (or returned from a vector search, in which case
/// the optional `_distance` column is propagated into `distance`).
///
/// The `_embeddings` parameter is currently unused.
///
/// # Errors
/// Fails when a required column is missing or has an unexpected Arrow type.
pub fn batch_to_text_blocks(
    &self,
    batch: &RecordBatch,
    _embeddings: Option<&[Vec<f32>]>,
) -> Result<Vec<TextBlock>> {
    // Local helpers: fetch a required column and downcast it to the
    // expected concrete array type.
    fn utf8_column<'a>(batch: &'a RecordBatch, name: &str) -> Result<&'a StringArray> {
        batch
            .column_by_name(name)
            .ok_or_else(|| anyhow::anyhow!("{} column not found", name))?
            .as_any()
            .downcast_ref::<StringArray>()
            .ok_or_else(|| anyhow::anyhow!("{} column is not a StringArray", name))
    }
    fn u32_column<'a>(batch: &'a RecordBatch, name: &str) -> Result<&'a UInt32Array> {
        batch
            .column_by_name(name)
            .ok_or_else(|| anyhow::anyhow!("{} column not found", name))?
            .as_any()
            .downcast_ref::<UInt32Array>()
            .ok_or_else(|| anyhow::anyhow!("{} column is not a UInt32Array", name))
    }
    let paths = utf8_column(batch, "path")?;
    let languages = utf8_column(batch, "language")?;
    let contents = utf8_column(batch, "content")?;
    let start_lines = u32_column(batch, "start_line")?;
    let end_lines = u32_column(batch, "end_line")?;
    let hashes = utf8_column(batch, "hash")?;
    // `_distance` only exists on search results; its absence means no scores.
    let distances: Vec<f32> = batch
        .column_by_name("_distance")
        .and_then(|col| col.as_any().downcast_ref::<Float32Array>())
        .map(|arr| arr.values().to_vec())
        .unwrap_or_default();
    let mut text_blocks = Vec::with_capacity(batch.num_rows());
    for row in 0..batch.num_rows() {
        text_blocks.push(TextBlock {
            path: paths.value(row).to_string(),
            language: languages.value(row).to_string(),
            content: contents.value(row).to_string(),
            // NOTE(review): line numbers are written verbatim by
            // text_block_to_batch but read back with +1 here — presumably
            // 0-based storage to 1-based display; confirm with the writer.
            start_line: (start_lines.value(row) as usize) + 1,
            end_line: (end_lines.value(row) as usize) + 1,
            hash: hashes.value(row).to_string(),
            distance: distances.get(row).copied(),
        });
    }
    Ok(text_blocks)
}
/// Reconstructs `DocumentBlock`s from a `RecordBatch` produced by
/// `document_block_to_batch` (or returned from a vector search, in which
/// case the optional `_distance` column is propagated into `distance`).
///
/// The `_embeddings` parameter is currently unused.
///
/// # Errors
/// Returns an error when a required column is missing, has an unexpected
/// Arrow type, or when a context list's items are not strings.
pub fn batch_to_document_blocks(
&self,
batch: &RecordBatch,
_embeddings: Option<&[Vec<f32>]>,
) -> Result<Vec<DocumentBlock>> {
let mut document_blocks = Vec::new();
// Required columns; each lookup fails fast with a descriptive error.
let path_array = batch
.column_by_name("path")
.ok_or_else(|| anyhow::anyhow!("path column not found"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("path column is not a StringArray"))?;
let title_array = batch
.column_by_name("title")
.ok_or_else(|| anyhow::anyhow!("title column not found"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("title column is not a StringArray"))?;
let content_array = batch
.column_by_name("content")
.ok_or_else(|| anyhow::anyhow!("content column not found"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("content column is not a StringArray"))?;
// `context` is a nested List<Utf8> column (one list of strings per row).
let context_array = batch
.column_by_name("context")
.ok_or_else(|| anyhow::anyhow!("context column not found"))?
.as_any()
.downcast_ref::<ListArray>()
.ok_or_else(|| anyhow::anyhow!("context column is not a ListArray"))?;
let start_line_array = batch
.column_by_name("start_line")
.ok_or_else(|| anyhow::anyhow!("start_line column not found"))?
.as_any()
.downcast_ref::<UInt32Array>()
.ok_or_else(|| anyhow::anyhow!("start_line column is not a UInt32Array"))?;
let end_line_array = batch
.column_by_name("end_line")
.ok_or_else(|| anyhow::anyhow!("end_line column not found"))?
.as_any()
.downcast_ref::<UInt32Array>()
.ok_or_else(|| anyhow::anyhow!("end_line column is not a UInt32Array"))?;
let hash_array = batch
.column_by_name("hash")
.ok_or_else(|| anyhow::anyhow!("hash column not found"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("hash column is not a StringArray"))?;
// `_distance` only exists on search results; its absence yields an empty
// vec, so every row's `distance` becomes None below.
let distance_array = batch
.column_by_name("_distance")
.and_then(|col| col.as_any().downcast_ref::<Float32Array>())
.map(|arr| (0..arr.len()).map(|i| arr.value(i)).collect::<Vec<f32>>())
.unwrap_or_default();
for i in 0..batch.num_rows() {
// A null context entry decodes to an empty context; null items inside a
// non-null list are skipped.
let context = if context_array.is_null(i) {
Vec::new()
} else {
let list_values = context_array.value(i);
let string_array = list_values
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("Context list items are not strings"))?;
let mut context_vec = Vec::new();
for j in 0..string_array.len() {
if !string_array.is_null(j) {
context_vec.push(string_array.value(j).to_string());
}
}
context_vec
};
// `level` is tolerated as optional: a missing or mistyped column falls
// back to 0 rather than erroring (unlike the required columns above).
let level = if let Some(level_col) = batch.column_by_name("level") {
if let Some(level_array) = level_col.as_any().downcast_ref::<UInt32Array>() {
level_array.value(i) as usize
} else {
0 }
} else {
0 };
// NOTE(review): start/end lines are written verbatim by
// document_block_to_batch but read back with +1 here — presumably
// 0-based storage to 1-based display; confirm with the writer's callers.
let document_block = DocumentBlock {
path: path_array.value(i).to_string(),
title: title_array.value(i).to_string(),
content: content_array.value(i).to_string(),
context,
level,
start_line: (start_line_array.value(i) as usize) + 1, end_line: (end_line_array.value(i) as usize) + 1, hash: hash_array.value(i).to_string(),
distance: distance_array.get(i).copied(),
};
document_blocks.push(document_block);
}
Ok(document_blocks)
}
pub fn commit_block_to_batch(
&self,
blocks: &[CommitBlock],
embeddings: &[Vec<f32>],
) -> Result<RecordBatch> {
if blocks.len() != embeddings.len() {
return Err(anyhow::anyhow!(
"Number of blocks and embeddings must match"
));
}
if blocks.is_empty() {
return Err(anyhow::anyhow!("Empty blocks array"));
}
for (i, embedding) in embeddings.iter().enumerate() {
if embedding.len() != self.vector_dim {
return Err(anyhow::anyhow!(
"Embedding at index {} has dimension {} but expected {}",
i,
embedding.len(),
self.vector_dim
));
}
}
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("hash", DataType::Utf8, false),
Field::new("author", DataType::Utf8, false),
Field::new("date", DataType::Int64, false),
Field::new("message", DataType::Utf8, false),
Field::new("content", DataType::Utf8, false),
Field::new("files", DataType::Utf8, false),
Field::new("description", DataType::Utf8, false),
Field::new(
"embedding",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
self.vector_dim as i32,
),
true,
),
]));
let ids: Vec<String> = blocks.iter().map(|_| Uuid::new_v4().to_string()).collect();
let hashes: Vec<&str> = blocks.iter().map(|b| b.hash.as_str()).collect();
let authors: Vec<&str> = blocks.iter().map(|b| b.author.as_str()).collect();
let dates: Vec<i64> = blocks.iter().map(|b| b.date).collect();
let messages: Vec<&str> = blocks.iter().map(|b| b.message.as_str()).collect();
let contents: Vec<&str> = blocks.iter().map(|b| b.content.as_str()).collect();
let files: Vec<&str> = blocks.iter().map(|b| b.files.as_str()).collect();
let descriptions: Vec<&str> = blocks.iter().map(|b| b.description.as_str()).collect();
let mut flattened_embeddings = Vec::with_capacity(blocks.len() * self.vector_dim);
for embedding in embeddings {
flattened_embeddings.extend_from_slice(embedding);
}
let values = Float32Array::from(flattened_embeddings);
let embedding_array = FixedSizeListArray::new(
Arc::new(Field::new("item", DataType::Float32, true)),
self.vector_dim as i32,
Arc::new(values),
None,
);
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(StringArray::from(ids)),
Arc::new(StringArray::from(hashes)),
Arc::new(StringArray::from(authors)),
Arc::new(Int64Array::from(dates)),
Arc::new(StringArray::from(messages)),
Arc::new(StringArray::from(contents)),
Arc::new(StringArray::from(files)),
Arc::new(StringArray::from(descriptions)),
Arc::new(embedding_array),
],
)?;
Ok(batch)
}
pub fn batch_to_commit_blocks(&self, batch: &RecordBatch) -> Result<Vec<CommitBlock>> {
let mut commit_blocks = Vec::new();
let hash_array = batch
.column_by_name("hash")
.ok_or_else(|| anyhow::anyhow!("hash column not found"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("hash column is not a StringArray"))?;
let author_array = batch
.column_by_name("author")
.ok_or_else(|| anyhow::anyhow!("author column not found"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("author column is not a StringArray"))?;
let date_array = batch
.column_by_name("date")
.ok_or_else(|| anyhow::anyhow!("date column not found"))?
.as_any()
.downcast_ref::<Int64Array>()
.ok_or_else(|| anyhow::anyhow!("date column is not an Int64Array"))?;
let message_array = batch
.column_by_name("message")
.ok_or_else(|| anyhow::anyhow!("message column not found"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("message column is not a StringArray"))?;
let content_array = batch
.column_by_name("content")
.ok_or_else(|| anyhow::anyhow!("content column not found"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("content column is not a StringArray"))?;
let files_array = batch
.column_by_name("files")
.ok_or_else(|| anyhow::anyhow!("files column not found"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("files column is not a StringArray"))?;
let description_array = batch
.column_by_name("description")
.ok_or_else(|| anyhow::anyhow!("description column not found"))?
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| anyhow::anyhow!("description column is not a StringArray"))?;
let distance_array = batch
.column_by_name("_distance")
.and_then(|col| col.as_any().downcast_ref::<Float32Array>())
.map(|arr| (0..arr.len()).map(|i| arr.value(i)).collect::<Vec<f32>>())
.unwrap_or_default();
for i in 0..batch.num_rows() {
commit_blocks.push(CommitBlock {
hash: hash_array.value(i).to_string(),
author: author_array.value(i).to_string(),
date: date_array.value(i),
message: message_array.value(i).to_string(),
content: content_array.value(i).to_string(),
files: files_array.value(i).to_string(),
description: description_array.value(i).to_string(),
distance: distance_array.get(i).copied(),
});
}
Ok(commit_blocks)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Writing commit blocks to a batch and reading them back preserves
    /// every field (ids are regenerated and thus not compared).
    #[test]
    fn test_commit_block_batch_roundtrip() {
        let original = vec![
            CommitBlock {
                hash: "abc12345deadbeef1234567890abcdef12345678".to_string(),
                author: "Alice".to_string(),
                date: 1700000000,
                message: "feat: add retry logic".to_string(),
                content: "feat: add retry logic\n\nFiles: src/client.rs".to_string(),
                files: r#"["src/client.rs"]"#.to_string(),
                description: "Adds retry support".to_string(),
                distance: None,
            },
            CommitBlock {
                hash: "def67890deadbeef1234567890abcdef12345678".to_string(),
                author: "Bob".to_string(),
                date: 1700001000,
                message: "fix: resolve crash".to_string(),
                content: "fix: resolve crash\n\nFiles: src/main.rs".to_string(),
                files: r#"["src/main.rs"]"#.to_string(),
                description: String::new(),
                distance: None,
            },
        ];
        let embeddings = vec![vec![0.1, 0.2, 0.3, 0.4], vec![0.5, 0.6, 0.7, 0.8]];
        let converter = BatchConverter::new(4);
        let batch = converter
            .commit_block_to_batch(&original, &embeddings)
            .unwrap();
        assert_eq!(batch.num_rows(), 2);
        let decoded = converter.batch_to_commit_blocks(&batch).unwrap();
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0].hash, original[0].hash);
        assert_eq!(decoded[0].author, "Alice");
        assert_eq!(decoded[0].date, 1700000000);
        assert_eq!(decoded[0].message, "feat: add retry logic");
        assert_eq!(decoded[0].files, r#"["src/client.rs"]"#);
        assert_eq!(decoded[0].description, "Adds retry support");
        assert_eq!(decoded[1].hash, original[1].hash);
        assert_eq!(decoded[1].author, "Bob");
        assert_eq!(decoded[1].description, "");
    }

    /// Converting zero blocks is rejected up front.
    #[test]
    fn test_commit_block_batch_empty_errors() {
        let converter = BatchConverter::new(4);
        assert!(converter.commit_block_to_batch(&[], &[]).is_err());
    }

    /// An embedding whose length disagrees with the converter's dimension
    /// must be rejected with a dimension error.
    #[test]
    fn test_commit_block_batch_dimension_mismatch() {
        let blocks = vec![CommitBlock {
            hash: "abc".to_string(),
            author: "A".to_string(),
            date: 0,
            message: "m".to_string(),
            content: "c".to_string(),
            files: "[]".to_string(),
            description: String::new(),
            distance: None,
        }];
        let converter = BatchConverter::new(4);
        let wrong_dim = vec![vec![0.1, 0.2]];
        let result = converter.commit_block_to_batch(&blocks, &wrong_dim);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("dimension"));
    }
}