use crate::core::df::tabular::write_df_parquet;
use crate::model::merkle_tree::node::FileNodeWithDir;
use crate::view::data_frames::columns::NewColumn;
use polars::frame::DataFrame;
use sql_query_builder::Select;
use crate::constants::{MODS_DIR, OXEN_HIDDEN_DIR};
use crate::constants::{OXEN_COLS, TABLE_NAME};
use crate::core;
use crate::core::db::data_frames::df_db::with_df_db_manager;
use crate::core::db::data_frames::workspace_df_db::select_cols_from_schema;
use crate::core::db::data_frames::{df_db, workspace_df_db};
use crate::core::df::sql;
use crate::core::versions::MinOxenVersion;
use crate::error::OxenError;
use crate::model::{Branch, Commit, EntryDataType, LocalRepository, NewCommitBody, Workspace};
use crate::opts::DFOpts;
use crate::{repositories, util};
use crate::model::diff::tabular_diff::{
TabularDiffDupes, TabularDiffMods, TabularDiffParameters, TabularDiffSchemas,
TabularDiffSummary, TabularSchemaDiff,
};
use crate::model::diff::{AddRemoveModifyCounts, DiffResult, TabularDiff};
use crate::core::db::data_frames::columns::polar_insert_column;
use duckdb::arrow::array::RecordBatch;
use std::path::{Path, PathBuf};
pub mod columns;
pub mod embeddings;
pub mod rows;
pub mod schemas;
/// Reports whether the indexed copy of the data frame at `path` was built from a commit
/// other than the workspace's current base commit.
///
/// The commit id recorded at index time is read from the `COMMIT_ID` file located by
/// [`previous_commit_ref_path`] and compared against `workspace.commit.id`.
///
/// # Errors
/// Propagates any error from reading the `COMMIT_ID` file.
pub fn is_behind(workspace: &Workspace, path: impl AsRef<Path>) -> Result<bool, OxenError> {
    let recorded_commit_id =
        util::fs::read_from_path(previous_commit_ref_path(workspace, path))?;
    let behind = recorded_commit_id != workspace.commit.id;
    Ok(behind)
}
/// Returns whether the data frame at `path` has been indexed into this workspace.
///
/// "Indexed" means the `TABLE_NAME` table exists in the DuckDB database returned by
/// [`duckdb_path`] for this workspace/path pair.
///
/// # Errors
/// Propagates errors from opening the database or from the table-existence check.
pub fn is_indexed(workspace: &Workspace, path: impl AsRef<Path>) -> Result<bool, OxenError> {
    let path = path.as_ref();
    log::debug!("checking dataset is indexed for {path:?}");

    let db_path = duckdb_path(workspace, path);
    log::debug!("getting conn at path {db_path:?}");

    with_df_db_manager(db_path, |manager| {
        manager.with_conn(|conn| {
            let table_exists = df_db::table_exists(conn, TABLE_NAME)?;
            log::debug!("dataset_is_indexed() got table_exists: {table_exists:?}");
            Ok(table_exists)
        })
    })
}
pub fn is_queryable_data_frame_indexed(
repo: &LocalRepository,
path: impl AsRef<Path>,
commit: &Commit,
) -> Result<bool, OxenError> {
match repo.min_version() {
MinOxenVersion::V0_10_0 => panic!("v0.10.0 no longer supported"),
_ => core::v_latest::workspaces::data_frames::is_queryable_data_frame_indexed(
repo, commit, path,
),
}
}
pub fn get_queryable_data_frame_workspace(
repo: &LocalRepository,
path: impl AsRef<Path>,
commit: &Commit,
) -> Result<Workspace, OxenError> {
match repo.min_version() {
MinOxenVersion::V0_10_0 => {
panic!("get_queryable_data_frame_workspace not implemented for v0.10.0");
}
_ => core::v_latest::workspaces::data_frames::get_queryable_data_frame_workspace(
repo, path, commit,
),
}
}
pub async fn index(
repo: &LocalRepository,
workspace: &Workspace,
path: impl AsRef<Path>,
) -> Result<(), OxenError> {
match repo.min_version() {
MinOxenVersion::V0_10_0 => panic!("v0.10.0 no longer supported"),
_ => core::v_latest::workspaces::data_frames::index(workspace, path.as_ref()).await,
}
}
/// Renames the workspace data frame at `path` to `new_path` and returns the resulting
/// path as reported by the `v_latest` implementation.
///
/// # Errors
/// Returns an error when the workspace's base repository has minimum version v0.10.0.
/// Otherwise propagates errors from the `v_latest` implementation.
pub async fn rename(
    workspace: &Workspace,
    path: impl AsRef<Path>,
    new_path: impl AsRef<Path>,
) -> Result<PathBuf, OxenError> {
    if matches!(workspace.base_repo.min_version(), MinOxenVersion::V0_10_0) {
        return Err(OxenError::basic_str(
            "rename is not supported for this version of oxen",
        ));
    }
    core::v_latest::workspaces::data_frames::rename(workspace, path, new_path).await
}
pub fn unindex(workspace: &Workspace, path: impl AsRef<Path>) -> Result<(), OxenError> {
let path = path.as_ref();
let db_path = repositories::workspaces::data_frames::duckdb_path(workspace, path);
with_df_db_manager(db_path, |manager| {
manager.with_conn(|conn| {
df_db::drop_table(conn, TABLE_NAME)?;
Ok(())
})
})
}
/// Rebuilds the workspace index for `path` from scratch.
///
/// The existing staging table is dropped via [`unindex`], then re-created via [`index`].
///
/// # Errors
/// Propagates the first error from either step. If `index` fails, the table has already
/// been dropped.
pub async fn restore(
    repo: &LocalRepository,
    workspace: &Workspace,
    path: impl AsRef<Path>,
) -> Result<(), OxenError> {
    let df_path = path.as_ref();
    unindex(workspace, df_path)?;
    index(repo, workspace, df_path).await
}
pub fn count(workspace: &Workspace, path: impl AsRef<Path>) -> Result<usize, OxenError> {
let db_path = repositories::workspaces::data_frames::duckdb_path(workspace, path);
with_df_db_manager(db_path, |manager| {
manager.with_conn(|conn| {
let count = df_db::count(conn, TABLE_NAME)?;
Ok(count)
})
})
}
pub fn query(
workspace: &Workspace,
path: impl AsRef<Path>,
opts: &DFOpts,
) -> Result<DataFrame, OxenError> {
let path = path.as_ref();
let db_path = repositories::workspaces::data_frames::duckdb_path(workspace, path);
log::debug!("query_staged_df() got db_path: {db_path:?}");
log::debug!("query() opts: {opts:?}");
with_df_db_manager(db_path, |manager| {
manager.with_conn_mut(|conn| {
let schema = df_db::get_schema(conn, TABLE_NAME)?;
let col_names = select_cols_from_schema(&schema)?;
let df = if let Some(embedding_opts) = opts.get_sort_by_embedding_query() {
log::debug!("querying embeddings: {embedding_opts:?}");
repositories::workspaces::data_frames::embeddings::query_with_conn(
conn,
workspace,
&embedding_opts,
)?
} else if let Some(sql) = &opts.sql {
log::debug!("querying sql: {sql:?}");
return sql::query_df(conn, sql.clone(), None);
} else {
log::debug!("querying select cols: {col_names:?}");
let select = Select::new().select(&col_names).from(TABLE_NAME);
df_db::select(conn, &select, Some(opts))?
};
Ok(df)
})
})
}
pub fn export(
workspace: &Workspace,
path: impl AsRef<Path>,
opts: &DFOpts,
temp_file: impl AsRef<Path>,
) -> Result<(), OxenError> {
let path = path.as_ref();
let db_path = repositories::workspaces::data_frames::duckdb_path(workspace, path);
log::debug!("export() got db_path: {db_path:?}");
with_df_db_manager(db_path, |manager| {
manager.with_conn(|conn| {
let sql = if let Some(embedding_opts) = opts.get_sort_by_embedding_query() {
let exclude_cols = true;
repositories::workspaces::data_frames::embeddings::similarity_query_with_conn(
conn,
workspace,
&embedding_opts,
exclude_cols,
)?
} else if let Some(sql) = opts.sql.clone() {
add_exclude_to_sql(&sql)?
} else {
let sql = format!("SELECT * FROM {TABLE_NAME}");
add_exclude_to_sql(&sql)?
};
log::debug!("exporting data frame with sql: {sql:?}");
sql::export_df(conn, sql, Some(opts), temp_file)?;
Ok(())
})
})
}
pub fn diff(workspace: &Workspace, path: impl AsRef<Path>) -> Result<DataFrame, OxenError> {
let file_path = path.as_ref();
let staged_db_path = repositories::workspaces::data_frames::duckdb_path(workspace, file_path);
with_df_db_manager(staged_db_path, |manager| {
manager.with_conn(|conn| {
let diff_df = workspace_df_db::df_diff(conn)?;
Ok(diff_df)
})
})
}
pub fn full_diff(workspace: &Workspace, path: impl AsRef<Path>) -> Result<DiffResult, OxenError> {
let repo = &workspace.base_repo;
let path = path.as_ref();
log::debug!("diff_workspace_df got repo at path {:?}", repo.path);
if !is_indexed(workspace, path)? {
return Err(OxenError::basic_str("Dataset is not indexed"));
};
let db_path = repositories::workspaces::data_frames::duckdb_path(workspace, path);
with_df_db_manager(db_path, |manager| {
manager.with_conn(|conn| {
let diff_df = workspace_df_db::df_diff(conn)?;
log::debug!("full_diff() diff_df: {diff_df:?}");
if diff_df.is_empty() {
return Ok(DiffResult::Tabular(TabularDiff::empty()));
}
let row_mods = AddRemoveModifyCounts::from_diff_df(&diff_df)?;
let schema = workspace_df_db::schema_without_oxen_cols(conn, TABLE_NAME)?;
let schemas = TabularDiffSchemas {
left: schema.clone(),
right: schema.clone(),
diff: schema.clone(),
};
let diff_summary = TabularDiffSummary {
modifications: TabularDiffMods {
row_counts: row_mods,
col_changes: TabularSchemaDiff::empty(),
},
schemas,
dupes: TabularDiffDupes::empty(),
};
let diff_result = TabularDiff {
contents: diff_df,
parameters: TabularDiffParameters::empty(),
summary: diff_summary,
filename1: None,
filename2: None,
};
Ok(DiffResult::Tabular(diff_result))
})
})
}
/// Creates a parquet data frame listing the files under directory `path` and commits it.
///
/// Steps:
/// 1. Verifies `path` is a directory in `workspace.commit`.
/// 2. Loads its subtree with depth `-1` when `recursive` is set, otherwise depth `1`, and
///    lists the files in it.
/// 3. Loads each file's `dir/name` path into a `file_listing` table in a temporary DuckDB
///    database (`temp_file_listing.db` in the workspace dir), then adds every column in
///    `extra_columns`.
/// 4. Writes the table to `output_path`, which is joined onto the workspace dir and given
///    a `.parquet` extension unless it already has one (case-insensitive). The file is
///    then added to the workspace.
/// 5. If every listed file shares an image or video data type, records render metadata
///    on the `file_path` column.
/// 6. Commits the workspace with `new_commit` onto `branch` and returns that commit.
///
/// # Errors
/// Returns an error if the directory or its subtree cannot be found. Also propagates
/// failures from the database, file-system, staging, metadata and commit steps.
#[allow(clippy::too_many_arguments)]
pub async fn from_directory(
    repo: &LocalRepository,
    workspace: &Workspace,
    path: impl AsRef<Path>,
    output_path: impl AsRef<Path>,
    extra_columns: &[NewColumn],
    recursive: bool,
    new_commit: &NewCommitBody,
    branch: &Branch,
) -> Result<Commit, OxenError> {
    let dir_path = path.as_ref();
    if !repositories::tree::has_dir(repo, &workspace.commit, dir_path)? {
        return Err(OxenError::basic_str(format!(
            "Directory not found: {dir_path:?}"
        )));
    }

    let depth = if recursive { -1 } else { 1 };
    let Some(subtree) = repositories::tree::get_subtree(
        repo,
        &workspace.commit,
        &Some(dir_path.to_path_buf()),
        &Some(depth),
    )?
    else {
        // Report a missing subtree to the caller instead of panicking on `unwrap`.
        return Err(OxenError::basic_str(format!(
            "Could not load directory tree for: {dir_path:?}"
        )));
    };
    let files = repositories::tree::list_all_files(&subtree, &dir_path.to_path_buf())?;
    let file_paths: Vec<String> = files
        .iter()
        .map(|file_with_dir| {
            file_with_dir
                .dir
                .join(file_with_dir.file_node.name())
                .to_string_lossy()
                .to_string()
        })
        .collect();

    let db_path = workspace.dir().join("temp_file_listing.db");
    let mut df = with_df_db_manager(&db_path, |manager| {
        manager.with_conn(|conn| {
            // Nothing here removes the temp database, so a later call on the same workspace
            // would hit the table from the previous run; replace it rather than failing
            // with "table already exists" (this also discards stale extra columns).
            conn.execute(
                "CREATE OR REPLACE TABLE file_listing (file_path VARCHAR);",
                [],
            )?;
            if !file_paths.is_empty() {
                // One bound parameter per path keeps file names out of the SQL text.
                let values_clause = file_paths
                    .iter()
                    .map(|_| "(?)")
                    .collect::<Vec<_>>()
                    .join(", ");
                let bulk_insert_sql =
                    format!("INSERT INTO file_listing (file_path) VALUES {values_clause}");
                let params: Vec<&dyn duckdb::ToSql> = file_paths
                    .iter()
                    .map(|path| path as &dyn duckdb::ToSql)
                    .collect();
                conn.execute(&bulk_insert_sql, params.as_slice())?;
            }
            for new_column in extra_columns {
                polar_insert_column(conn, "file_listing", new_column)?;
            }
            let sql_query = "SELECT * FROM file_listing";
            let result_set: Vec<RecordBatch> = conn.prepare(sql_query)?.query_arrow([])?.collect();
            df_db::record_batches_to_polars_df(result_set)
        })
    })?;

    let output_path = workspace.dir().join(output_path);
    let output_path = if output_path
        .extension()
        .and_then(|ext| ext.to_str())
        .is_some_and(|ext| ext.eq_ignore_ascii_case("parquet"))
    {
        output_path
    } else {
        output_path.with_extension("parquet")
    };
    if let Some(parent) = output_path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    write_df_parquet(&mut df, &output_path)?;
    repositories::workspaces::files::add(workspace, &output_path).await?;

    let files_vec: Vec<FileNodeWithDir> = files.iter().cloned().collect();
    set_media_render_metadata_if_applicable(repo, workspace, &files_vec, &output_path).await?;

    let commit =
        repositories::workspaces::commit(workspace, new_commit, branch.name.as_str()).await?;
    println!(
        "Created parquet file with {} file paths at: {:?}",
        file_paths.len(),
        output_path
    );
    Ok(commit)
}
/// Maps an entry data type to the render `func` name written into column metadata.
///
/// Only images (`"image"`) and videos (`"video"`) have a render function. Every other
/// type yields `None`.
fn render_function_for_media_type(data_type: &EntryDataType) -> Option<&'static str> {
    let func = match data_type {
        EntryDataType::Image => "image",
        EntryDataType::Video => "video",
        _ => return None,
    };
    Some(func)
}
/// Returns the shared render function name when every file has the same renderable data
/// type.
///
/// Yields `None` if `files` is empty, if the first file's type has no render function, or
/// if any file's data type differs from the first file's.
fn get_uniform_media_type(files: &[FileNodeWithDir]) -> Option<&'static str> {
    let leading = files.first()?;
    let leading_type = leading.file_node.data_type();
    let render_func = render_function_for_media_type(leading_type)?;
    files
        .iter()
        .all(|entry| entry.file_node.data_type() == leading_type)
        .then_some(render_func)
}
/// Tags the `file_path` column of the data frame at `output_path` with render metadata
/// when all `files` share one renderable media type.
///
/// The metadata written is `{"_oxen": {"render": {"func": <image|video>}}}`. When the
/// files are not uniformly image or video, nothing is written and `Ok(())` is returned.
///
/// # Errors
/// Propagates errors from `add_column_metadata`.
pub async fn set_media_render_metadata_if_applicable(
    repo: &LocalRepository,
    workspace: &Workspace,
    files: &[FileNodeWithDir],
    output_path: impl AsRef<Path>,
) -> Result<(), OxenError> {
    let Some(render_func) = get_uniform_media_type(files) else {
        return Ok(());
    };

    let render_metadata = serde_json::json!({
        "_oxen": {
            "render": {
                "func": render_func
            }
        }
    });
    repositories::workspaces::data_frames::columns::add_column_metadata(
        repo,
        workspace,
        output_path.as_ref().to_path_buf(),
        "file_path".to_string(),
        &render_metadata,
    )?;
    Ok(())
}
/// Per-data-frame staging directory: `<workspace>/<OXEN_HIDDEN_DIR>/<MODS_DIR>/duckdb/<hash>`.
///
/// `<hash>` is the hash of `path` after `util::fs::linux_path` normalization. All files
/// belonging to one data frame (database, `COMMIT_ID`, change logs) are derived from this
/// single helper, so they always land in the same directory regardless of which
/// separator style the caller used.
fn duckdb_dir(workspace: &Workspace, path: &Path) -> PathBuf {
    let normalized = util::fs::linux_path(path);
    let path_hash = util::hasher::hash_str(normalized.to_string_lossy());
    workspace
        .dir()
        .join(OXEN_HIDDEN_DIR)
        .join(MODS_DIR)
        .join("duckdb")
        .join(path_hash)
}

/// Path of the DuckDB database that stages edits for the data frame at `path`.
pub fn duckdb_path(workspace: &Workspace, path: impl AsRef<Path>) -> PathBuf {
    let path = path.as_ref();
    log::debug!(
        "duckdb_path path: {:?} workspace: {:?}",
        util::fs::linux_path(path),
        workspace.dir()
    );
    duckdb_dir(workspace, path).join("db")
}

/// Path of the `COMMIT_ID` file for the data frame at `path`.
///
/// [`is_behind`] reads this file and compares its contents to the workspace's commit id.
/// It is hashed with the same normalized path as [`duckdb_path`] so it sits next to the
/// database.
pub fn previous_commit_ref_path(workspace: &Workspace, path: impl AsRef<Path>) -> PathBuf {
    duckdb_dir(workspace, path.as_ref()).join("COMMIT_ID")
}

/// Path of the `column_changes` entry for the data frame at `path`, alongside its
/// database.
pub fn column_changes_path(workspace: &Workspace, path: impl AsRef<Path>) -> PathBuf {
    duckdb_dir(workspace, path.as_ref()).join("column_changes")
}

/// Path of the `row_changes` entry for the data frame at `path`, alongside its database.
pub fn row_changes_path(workspace: &Workspace, path: impl AsRef<Path>) -> PathBuf {
    duckdb_dir(workspace, path.as_ref()).join("row_changes")
}
/// Finds the first whole-word occurrence of `keyword` in `lowered` at or after byte `start`.
///
/// `lowered` must already be ASCII-lowercased and `keyword` must be lowercase ASCII. A
/// match only counts when it is not directly preceded or followed by an ASCII identifier
/// character (letter, digit or `_`). For example, `from` inside `from_date` is skipped.
fn find_sql_keyword(lowered: &str, keyword: &str, start: usize) -> Option<usize> {
    let bytes = lowered.as_bytes();
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    lowered
        .get(start..)?
        .match_indices(keyword)
        .map(|(offset, _)| start + offset)
        .find(|&idx| {
            let clear_before = idx == 0 || !is_ident(bytes[idx - 1]);
            let clear_after = !bytes
                .get(idx + keyword.len())
                .is_some_and(|&b| is_ident(b));
            clear_before && clear_after
        })
}

/// Rewrites `sql` so a star projection excludes Oxen's internal columns (`OXEN_COLS`).
///
/// The select list is the text between the first `SELECT` keyword and the following
/// `FROM` keyword (both matched case-insensitively, as whole words).
/// - If the query contains `group by`, it is returned unchanged apart from trimming
///   whitespace around the select list.
/// - If the select list ends with `*` (`SELECT *`, `SELECT t.*`, `SELECT a, *`),
///   ` EXCLUDE ("col", ...)` is appended right after it. Any text before `SELECT`, such as
///   a `WITH` clause, is kept.
/// - Otherwise the select list names explicit expressions, to which DuckDB's `EXCLUDE`
///   modifier does not apply, so it is left as is.
///
/// # Errors
/// Returns an error if no `SELECT` keyword, or no `FROM` keyword after it, is found.
fn add_exclude_to_sql(sql: &str) -> Result<String, OxenError> {
    let excluded_cols = OXEN_COLS
        .iter()
        .map(|col| format!("\"{col}\""))
        .collect::<Vec<String>>()
        .join(", ");

    // ASCII-only lowercasing keeps byte offsets identical to `sql`, so indices found in
    // `lowered` are valid split points in `sql`. Full Unicode `to_lowercase` can change
    // byte lengths, misaligning the split or panicking on a char boundary.
    let lowered = sql.to_ascii_lowercase();
    let select_idx = find_sql_keyword(&lowered, "select", 0)
        .ok_or_else(|| OxenError::basic_str("No SELECT found in query"))?;
    let from_idx = find_sql_keyword(&lowered, "from", select_idx + "select".len())
        .ok_or_else(|| OxenError::basic_str("No FROM found in query"))?;

    let (before_from, after_from) = sql.split_at(from_idx);
    let projection = before_from.trim();

    let modified_select = if !lowered.contains("group by") && projection.ends_with('*') {
        format!("{projection} EXCLUDE ({excluded_cols})")
    } else {
        projection.to_string()
    };
    Ok(format!("{modified_select} {after_from}"))
}
#[cfg(test)]
mod tests {
use std::path::Path;
use serde_json::json;
use super::*;
use crate::config::UserConfig;
use crate::constants::{DEFAULT_BRANCH_NAME, OXEN_ID_COL};
use crate::core::df;
use crate::error::OxenError;
use crate::model::NewCommitBody;
use crate::model::diff::DiffResult;
use crate::opts::DFOpts;
use crate::repositories::workspaces;
use crate::test;
use crate::{repositories, util};
/// Appending one row to an indexed CSV stages exactly one file, and the workspace diff
/// reports one added row. Skipped on Windows.
#[tokio::test]
async fn test_add_row() -> Result<(), OxenError> {
if std::env::consts::OS == "windows" {
return Ok(());
}
test::run_training_data_repo_test_fully_committed_async(|repo| async move {
let branch_name = "test-append";
let branch = repositories::branches::create_checkout(&repo, branch_name)?;
let commit = repositories::commits::get_by_id(&repo, &branch.commit_id)?.unwrap();
let workspace_id = UserConfig::identifier()?;
let workspace = repositories::workspaces::create(&repo, &commit, workspace_id, true)?;
let file_path = Path::new("annotations")
.join("train")
.join("bounding_box.csv");
// The data frame must be indexed before rows can be edited.
workspaces::data_frames::index(&repo, &workspace, &file_path).await?;
let json_data = json!({
"file": "dawg1.jpg",
"label": "dog",
"min_x": 13,
"min_y": 14,
"width": 100,
"height": 100
});
workspaces::data_frames::rows::add(&repo, &workspace, &file_path, &json_data)?;
let status = workspaces::status::status(&workspace)?;
assert_eq!(status.staged_files.len(), 1);
let diff = workspaces::diff(&repo, &workspace, &file_path)?;
match diff {
DiffResult::Tabular(tabular_diff) => {
let added_rows = tabular_diff.summary.modifications.row_counts.added;
assert_eq!(added_rows, 1);
}
_ => panic!("Expected tabular diff result"),
}
Ok(())
})
.await
}
/// After appending two rows (diff: 2 added), deleting the first appended row by its
/// `OXEN_ID_COL` value leaves the diff at 1 added row. Skipped on Windows.
#[tokio::test]
async fn test_delete_added_row_with_two_rows() -> Result<(), OxenError> {
if std::env::consts::OS == "windows" {
return Ok(());
}
test::run_training_data_repo_test_fully_committed_async(|repo| async move {
let branch_name = "test-append";
let branch = repositories::branches::create_checkout(&repo, branch_name)?;
let commit = repositories::commits::get_by_id(&repo, &branch.commit_id)?.unwrap();
let workspace_id = UserConfig::identifier()?;
let workspace = repositories::workspaces::create(&repo, &commit, workspace_id, true)?;
let file_path = Path::new("annotations")
.join("train")
.join("bounding_box.csv");
workspaces::data_frames::index(&repo, &workspace, &file_path).await?;
let json_data = json!({
"file": "dawg1.jpg",
"label": "dog",
"min_x": 13,
"min_y": 14,
"width": 100,
"height": 100
});
let append_entry_1 =
workspaces::data_frames::rows::add(&repo, &workspace, &file_path, &json_data)?;
// Strip double quotes from the rendered value so the raw id can be passed to `delete`.
let append_1_id = append_entry_1.column(OXEN_ID_COL)?.get(0)?.to_string();
let append_1_id = append_1_id.replace('"', "");
let json_data = json!({
"file": "dawg2.jpg",
"label": "dog",
"min_x": 13,
"min_y": 14,
"width": 100,
"height": 100
});
let _append_entry_2 =
workspaces::data_frames::rows::add(&repo, &workspace, &file_path, &json_data)?;
let status = workspaces::status::status(&workspace)?;
log::debug!("status is {status:?}");
assert_eq!(status.staged_files.len(), 1);
let diff = workspaces::diff(&repo, &workspace, &file_path)?;
match diff {
DiffResult::Tabular(tabular_diff) => {
let added_rows = tabular_diff.summary.modifications.row_counts.added;
assert_eq!(added_rows, 2);
}
_ => panic!("Expected tabular diff result"),
}
workspaces::data_frames::rows::delete(&repo, &workspace, &file_path, &append_1_id)?;
let diff = workspaces::diff(&repo, &workspace, &file_path)?;
match diff {
DiffResult::Tabular(tabular_diff) => {
let added_rows = tabular_diff.summary.modifications.row_counts.added;
assert_eq!(added_rows, 1);
}
_ => panic!("Expected tabular diff result"),
}
Ok(())
})
.await
}
/// Appending two rows and then deleting both leaves no staged files and a diff with zero
/// added rows. Skipped on Windows.
#[tokio::test]
async fn test_clear_changes() -> Result<(), OxenError> {
if std::env::consts::OS == "windows" {
return Ok(());
}
test::run_training_data_repo_test_fully_committed_async(|repo| async move {
let branch_name = "test-append";
let branch = repositories::branches::create_checkout(&repo, branch_name)?;
let commit = repositories::commits::get_by_id(&repo, &branch.commit_id)?.unwrap();
let workspace_id = UserConfig::identifier()?;
let workspace = repositories::workspaces::create(&repo, &commit, workspace_id, true)?;
let file_path = Path::new("annotations")
.join("train")
.join("bounding_box.csv");
workspaces::data_frames::index(&repo, &workspace, &file_path).await?;
let json_data = json!({
"file": "dawg1.jpg",
"label": "dog",
"min_x": 13,
"min_y": 14,
"width": 100,
"height": 100
});
let append_entry_1 =
workspaces::data_frames::rows::add(&repo, &workspace, &file_path, &json_data)?;
let append_1_id = append_entry_1.column(OXEN_ID_COL)?.get(0)?;
let append_1_id = append_1_id.get_str().unwrap();
log::debug!("added the row");
let json_data = json!({
"file": "dawg2.jpg",
"label": "dog",
"min_x": 13,
"min_y": 14,
"width": 100,
"height": 100
});
let append_entry_2 =
workspaces::data_frames::rows::add(&repo, &workspace, &file_path, &json_data)?;
let append_2_id = append_entry_2.column(OXEN_ID_COL)?.get(0)?;
let append_2_id = append_2_id.get_str().unwrap();
let status = workspaces::status::status(&workspace)?;
assert_eq!(status.staged_files.len(), 1);
let diff = workspaces::diff(&repo, &workspace, &file_path)?;
match diff {
DiffResult::Tabular(tabular_diff) => {
let added_rows = tabular_diff.summary.modifications.row_counts.added;
assert_eq!(added_rows, 2);
}
_ => panic!("Expected tabular diff result"),
}
// Removing every appended row should return the file to an unstaged state.
workspaces::data_frames::rows::delete(&repo, &workspace, &file_path, append_1_id)?;
workspaces::data_frames::rows::delete(&repo, &workspace, &file_path, append_2_id)?;
let status = workspaces::status::status(&workspace)?;
assert_eq!(status.staged_files.len(), 0);
log::debug!("about to diff staged");
let diff = workspaces::diff(&repo, &workspace, &file_path)?;
log::debug!("got diff staged");
match diff {
DiffResult::Tabular(tabular_diff) => {
let added_rows = tabular_diff.summary.modifications.row_counts.added;
assert_eq!(added_rows, 0);
}
_ => panic!("Expected tabular diff result"),
}
Ok(())
})
.await
}
/// Deleting an already-committed row shows as 1 removed row in the workspace diff.
/// After committing the workspace, a file diff between the original commit's version
/// and the new commit's version also reports 1 removed row. Skipped on Windows.
#[tokio::test]
async fn test_delete_committed_row() -> Result<(), OxenError> {
if std::env::consts::OS == "windows" {
return Ok(());
}
test::run_training_data_repo_test_fully_committed_async(|repo| async move {
let branch_name = "test-append";
let branch = repositories::branches::create_checkout(&repo, branch_name)?;
let commit = repositories::commits::get_by_id(&repo, &branch.commit_id)?.unwrap();
let workspace_id = UserConfig::identifier()?;
let workspace = repositories::workspaces::create(&repo, &commit, workspace_id, true)?;
let file_path = Path::new("annotations")
.join("train")
.join("bounding_box.csv");
workspaces::data_frames::index(&repo, &workspace, &file_path).await?;
// Fetch the first page so we can pick an existing row's id to delete.
let mut page_opts = DFOpts::empty();
page_opts.page = Some(0);
page_opts.page_size = Some(10);
let staged_df = workspaces::data_frames::query(&workspace, &file_path, &page_opts)?;
let id_to_delete = staged_df.column(OXEN_ID_COL)?.get(0)?.to_string();
let id_to_delete = id_to_delete.replace('"', "");
workspaces::data_frames::rows::delete(&repo, &workspace, &file_path, &id_to_delete)?;
let status = workspaces::status::status(&workspace)?;
assert_eq!(status.staged_files.len(), 1);
let diff = workspaces::diff(&repo, &workspace, &file_path)?;
match diff {
DiffResult::Tabular(tabular_diff) => {
let removed_rows = tabular_diff.summary.modifications.row_counts.removed;
assert_eq!(removed_rows, 1);
}
_ => panic!("Expected tabular diff result"),
}
let status = repositories::status(&repo).await?;
log::debug!("got this status {status:?}");
let new_commit = NewCommitBody {
author: "author".to_string(),
email: "email".to_string(),
message: "Deleting a row allegedly".to_string(),
};
let commit_2 = workspaces::commit(&workspace, &new_commit, branch_name).await?;
// Copy both version files to `.csv` paths before diffing them as tables.
let file_1 = repositories::revisions::get_version_file_from_commit_id(
&repo, &commit.id, &file_path,
)
.await?;
let file_1_csv = file_1.with_extension("csv");
util::fs::copy(&*file_1, &file_1_csv)?;
log::debug!("copied file 1 to {file_1_csv:?}");
let file_2 = repositories::revisions::get_version_file_from_commit_id(
&repo,
commit_2.id,
&file_path,
)
.await?;
let file_2_csv = file_2.with_extension("csv");
util::fs::copy(&*file_2, &file_2_csv)?;
log::debug!("copied file 2 to {file_2_csv:?}");
let diff_result =
repositories::diffs::diff_files(file_1_csv, file_2_csv, vec![], vec![], vec![])
.await?;
log::debug!("diff result is {diff_result:?}");
match diff_result {
DiffResult::Tabular(tabular_diff) => {
let removed_rows = tabular_diff.summary.modifications.row_counts.removed;
assert_eq!(removed_rows, 1);
}
_ => panic!("Expected tabular diff result"),
}
Ok(())
})
.await
}
/// Updating a row that was itself added in the workspace keeps it counted as an addition:
/// the diff shows 1 added and 0 modified rows. Skipped on Windows.
#[tokio::test]
async fn test_modify_added_row() -> Result<(), OxenError> {
if std::env::consts::OS == "windows" {
return Ok(());
}
test::run_training_data_repo_test_fully_committed_async(|repo| async move {
let branch_name = "test-append";
let branch = repositories::branches::create_checkout(&repo, branch_name)?;
let commit = repositories::commits::get_by_id(&repo, &branch.commit_id)?.unwrap();
let workspace_id = UserConfig::identifier()?;
let workspace = repositories::workspaces::create(&repo, &commit, workspace_id, true)?;
let file_path = Path::new("annotations")
.join("train")
.join("bounding_box.csv");
workspaces::data_frames::index(&repo, &workspace, &file_path).await?;
let json_data = json!({
"min_x": 13,
"min_y": 14,
"width": 100,
"height": 100
});
let new_row =
workspaces::data_frames::rows::add(&repo, &workspace, &file_path, &json_data)?;
let diff = workspaces::diff(&repo, &workspace, &file_path)?;
match diff {
DiffResult::Tabular(tabular_diff) => {
let added_rows = tabular_diff.summary.modifications.row_counts.added;
assert_eq!(added_rows, 1);
}
_ => panic!("Expected tabular diff result"),
}
let id_to_modify = new_row.column(OXEN_ID_COL)?.get(0)?;
let id_to_modify = id_to_modify.get_str().unwrap();
let json_data = json!({
"height": 101
});
workspaces::data_frames::rows::update(
&repo,
&workspace,
&file_path,
id_to_modify,
&json_data,
)?;
let status = workspaces::status::status(&workspace)?;
log::debug!("found mod entries: {status:?}");
assert_eq!(status.staged_files.len(), 1);
let diff = workspaces::diff(&repo, &workspace, &file_path)?;
match diff {
DiffResult::Tabular(tabular_diff) => {
let modified_rows = tabular_diff.summary.modifications.row_counts.modified;
let added_rows = tabular_diff.summary.modifications.row_counts.added;
assert_eq!(modified_rows, 0);
assert_eq!(added_rows, 1);
}
_ => panic!("Expected tabular diff result"),
}
Ok(())
})
.await
}
/// Adding a row and then deleting it leaves no staged files and a diff with zero removed
/// rows. Skipped on Windows.
#[tokio::test]
async fn test_delete_added_single_row() -> Result<(), OxenError> {
if std::env::consts::OS == "windows" {
return Ok(());
}
test::run_training_data_repo_test_fully_committed_async(|repo| async move {
let branch_name = "test-append";
let branch = repositories::branches::create_checkout(&repo, branch_name)?;
let commit = repositories::commits::get_by_id(&repo, &branch.commit_id)?.unwrap();
let workspace_id = UserConfig::identifier()?;
let workspace = repositories::workspaces::create(&repo, &commit, workspace_id, true)?;
let file_path = Path::new("annotations")
.join("train")
.join("bounding_box.csv");
workspaces::data_frames::index(&repo, &workspace, &file_path).await?;
let json_data = json!({
"min_x": 13,
"min_y": 14,
"width": 100,
"height": 100
});
let new_row =
workspaces::data_frames::rows::add(&repo, &workspace, &file_path, &json_data)?;
let diff = workspaces::diff(&repo, &workspace, &file_path)?;
match diff {
DiffResult::Tabular(tabular_diff) => {
let added_rows = tabular_diff.summary.modifications.row_counts.added;
assert_eq!(added_rows, 1);
}
_ => panic!("Expected tabular diff result"),
}
// Strip double quotes from the rendered value so the raw id can be passed to `delete`.
let id_to_delete = new_row.column(OXEN_ID_COL)?.get(0)?.to_string();
let id_to_delete = id_to_delete.replace('"', "");
workspaces::data_frames::rows::delete(&repo, &workspace, &file_path, &id_to_delete)?;
log::debug!("done deleting row");
let status = workspaces::status::status(&workspace)?;
log::debug!("found mod entries: {status:?}");
assert_eq!(status.staged_files.len(), 0);
let diff = workspaces::diff(&repo, &workspace, &file_path)?;
match diff {
DiffResult::Tabular(tabular_diff) => {
let removed_rows = tabular_diff.summary.modifications.row_counts.removed;
assert_eq!(removed_rows, 0);
}
_ => panic!("Expected tabular diff result"),
}
Ok(())
})
.await
}
/// Changing a committed row's `label` to "doggo" gives 1 modified row. Setting it back to
/// "dog" leaves no staged files and 0 modified rows. Skipped on Windows.
#[tokio::test]
async fn test_modify_row_back_to_original_state() -> Result<(), OxenError> {
if std::env::consts::OS == "windows" {
return Ok(());
}
test::run_training_data_repo_test_fully_committed_async(|repo| async move {
let branch_name = "test-append";
let branch = repositories::branches::create_checkout(&repo, branch_name)?;
let commit = repositories::commits::get_by_id(&repo, &branch.commit_id)?.unwrap();
let workspace_id = UserConfig::identifier()?;
let workspace = repositories::workspaces::create(&repo, &commit, workspace_id, true)?;
let file_path = Path::new("annotations")
.join("train")
.join("bounding_box.csv");
workspaces::data_frames::index(&repo, &workspace, &file_path).await?;
let mut page_opts = DFOpts::empty();
page_opts.page = Some(0);
page_opts.page_size = Some(10);
let staged_df = workspaces::data_frames::query(&workspace, &file_path, &page_opts)?;
let id_to_modify = staged_df.column(OXEN_ID_COL)?.get(0)?.to_string();
let id_to_modify = id_to_modify.replace('"', "");
let json_data = json!({
"label": "doggo"
});
workspaces::data_frames::rows::update(
&repo,
&workspace,
&file_path,
&id_to_modify,
&json_data,
)?;
let status = workspaces::status::status(&workspace)?;
assert_eq!(status.staged_files.len(), 1);
let diff = workspaces::diff(&repo, &workspace, &file_path)?;
match diff {
DiffResult::Tabular(tabular_diff) => {
let modified_rows = tabular_diff.summary.modifications.row_counts.modified;
assert_eq!(modified_rows, 1);
}
_ => panic!("Expected tabular diff result"),
}
// Revert the edit by hand; the row should no longer count as modified.
let json_data = json!({
"label": "dog"
});
let res = workspaces::data_frames::rows::update(
&repo,
&workspace,
&file_path,
&id_to_modify,
&json_data,
)?;
log::debug!("res is... {res:?}");
let status = workspaces::status::status(&workspace)?;
assert_eq!(status.staged_files.len(), 0);
let diff = workspaces::diff(&repo, &workspace, &file_path)?;
match diff {
DiffResult::Tabular(tabular_diff) => {
let modified_rows = tabular_diff.summary.modifications.row_counts.modified;
assert_eq!(modified_rows, 0);
}
_ => panic!("Expected tabular diff result"),
}
Ok(())
})
.await
}
/// After modifying a committed row (1 modified), `rows::restore` on that row leaves no
/// staged files and a diff with zero modified, added and removed rows. Skipped on Windows.
#[tokio::test]
async fn test_restore_row_after_modification() -> Result<(), OxenError> {
if std::env::consts::OS == "windows" {
return Ok(());
}
test::run_training_data_repo_test_fully_committed_async(|repo| async move {
let branch_name = "test-append";
let branch = repositories::branches::create_checkout(&repo, branch_name)?;
let commit = repositories::commits::get_by_id(&repo, &branch.commit_id)?.unwrap();
let workspace_id = UserConfig::identifier()?;
let workspace = repositories::workspaces::create(&repo, &commit, workspace_id, true)?;
let file_path = Path::new("annotations")
.join("train")
.join("bounding_box.csv");
workspaces::data_frames::index(&repo, &workspace, &file_path).await?;
let mut page_opts = DFOpts::empty();
page_opts.page = Some(0);
page_opts.page_size = Some(10);
let staged_df = workspaces::data_frames::query(&workspace, &file_path, &page_opts)?;
let id_to_modify = staged_df.column(OXEN_ID_COL)?.get(0)?.to_string();
let id_to_modify = id_to_modify.replace('"', "");
let json_data = json!({
"label": "doggo"
});
workspaces::data_frames::rows::update(
&repo,
&workspace,
&file_path,
&id_to_modify,
&json_data,
)?;
let status = workspaces::status::status(&workspace)?;
println!("status: {status:?}");
assert_eq!(status.staged_files.len(), 1);
let diff = workspaces::diff(&repo, &workspace, &file_path)?;
match diff {
DiffResult::Tabular(tabular_diff) => {
let modified_rows = tabular_diff.summary.modifications.row_counts.modified;
assert_eq!(modified_rows, 1);
}
_ => panic!("Expected tabular diff result"),
}
let res = workspaces::data_frames::rows::restore(
&repo,
&workspace,
&file_path,
&id_to_modify,
)
.await?;
log::debug!("res is... {res:?}");
let status = workspaces::status::status(&workspace)?;
assert_eq!(status.staged_files.len(), 0);
let diff = workspaces::diff(&repo, &workspace, &file_path)?;
match diff {
DiffResult::Tabular(tabular_diff) => {
let modified_rows = tabular_diff.summary.modifications.row_counts.modified;
let added_rows = tabular_diff.summary.modifications.row_counts.added;
let removed_rows = tabular_diff.summary.modifications.row_counts.removed;
assert_eq!(modified_rows, 0);
assert_eq!(added_rows, 0);
assert_eq!(removed_rows, 0);
}
_ => panic!("Expected tabular diff result"),
}
Ok(())
})
.await
}
/// After deleting a committed row (1 removed), `rows::restore` on that row makes the
/// workspace status clean and the diff zero in every category. Skipped on Windows.
#[tokio::test]
async fn test_restore_row_delete() -> Result<(), OxenError> {
if std::env::consts::OS == "windows" {
return Ok(());
}
test::run_training_data_repo_test_fully_committed_async(|repo| async move {
let branch_name = "test-append";
let branch = repositories::branches::create_checkout(&repo, branch_name)?;
let commit = repositories::commits::get_by_id(&repo, &branch.commit_id)?.unwrap();
let workspace_id = UserConfig::identifier()?;
let workspace = repositories::workspaces::create(&repo, &commit, workspace_id, true)?;
let file_path = Path::new("annotations")
.join("train")
.join("bounding_box.csv");
workspaces::data_frames::index(&repo, &workspace, &file_path).await?;
let mut page_opts = DFOpts::empty();
page_opts.page = Some(0);
page_opts.page_size = Some(10);
let staged_df = workspaces::data_frames::query(&workspace, &file_path, &page_opts)?;
let id_to_delete = staged_df.column(OXEN_ID_COL)?.get(0)?.to_string();
let id_to_delete = id_to_delete.replace('"', "");
workspaces::data_frames::rows::delete(&repo, &workspace, &file_path, &id_to_delete)?;
let status = workspaces::status::status(&workspace)?;
println!("status: {status:?}");
assert_eq!(status.staged_files.len(), 1);
let diff = workspaces::diff(&repo, &workspace, &file_path)?;
match diff {
DiffResult::Tabular(tabular_diff) => {
let removed_rows = tabular_diff.summary.modifications.row_counts.removed;
assert_eq!(removed_rows, 1);
}
_ => panic!("Expected tabular diff result"),
}
workspaces::data_frames::rows::restore(&repo, &workspace, &file_path, &id_to_delete)
.await?;
let status = workspaces::status::status(&workspace)?;
println!("status: {status:?}");
assert!(status.is_clean());
let diff = workspaces::diff(&repo, &workspace, &file_path)?;
match diff {
DiffResult::Tabular(tabular_diff) => {
let modified_rows = tabular_diff.summary.modifications.row_counts.modified;
let added_rows = tabular_diff.summary.modifications.row_counts.added;
let removed_rows = tabular_diff.summary.modifications.row_counts.removed;
assert_eq!(modified_rows, 0);
assert_eq!(added_rows, 0);
assert_eq!(removed_rows, 0);
}
_ => panic!("Expected tabular diff result"),
}
Ok(())
})
.await
}
#[tokio::test]
async fn test_commit_tabular_append_invalid_column() -> Result<(), OxenError> {
    // This test does not run on Windows.
    if cfg!(target_os = "windows") {
        return Ok(());
    }
    test::run_training_data_repo_test_fully_committed_async(|repo| async move {
        let csv_path = Path::new("annotations")
            .join("train")
            .join("bounding_box.csv");

        // Base the workspace on the head commit of the currently checked-out branch.
        let current = repositories::branches::current_branch(&repo)?.unwrap();
        let head = repositories::commits::get_by_id(&repo, &current.commit_id)?.unwrap();
        let workspace_id = UserConfig::identifier()?;
        let workspace = repositories::workspaces::create(&repo, &head, workspace_id, true)?;
        workspaces::data_frames::index(&repo, &workspace, &csv_path).await?;

        // A row whose only key is not a column of the data frame must be rejected.
        let bogus_row = json!({"NOT_REAL_COLUMN": "images/test.jpg"});
        let outcome =
            workspaces::data_frames::rows::add(&repo, &workspace, &csv_path, &bogus_row);
        assert!(outcome.is_err());
        Ok(())
    })
    .await
}
#[tokio::test]
async fn test_commit_tabular_appends_staged() -> Result<(), OxenError> {
    // This test does not run on Windows.
    if cfg!(target_os = "windows") {
        return Ok(());
    }
    test::run_training_data_repo_test_fully_committed_async(|repo| async move {
        let csv_path = Path::new("annotations")
            .join("train")
            .join("bounding_box.csv");
        let head = repositories::commits::head_commit(&repo)?;
        let user = UserConfig::get()?.to_user();
        let workspace_id = UserConfig::identifier()?;
        let workspace = repositories::workspaces::create(&repo, &head, workspace_id, true)?;
        workspaces::data_frames::index(&repo, &workspace, &csv_path).await?;

        // Stage one appended row, then commit the workspace onto the default branch.
        let appended = json!({"file": "images/test.jpg", "label": "dog", "min_x": 2.0, "min_y": 3.0, "width": 100, "height": 120});
        workspaces::data_frames::rows::add(&repo, &workspace, &csv_path, &appended)?;
        let body = NewCommitBody {
            author: user.name.to_owned(),
            email: user.email,
            message: "Appending tabular data".to_string(),
        };
        let committed = workspaces::commit(&workspace, &body, DEFAULT_BRANCH_NAME).await?;

        // Load the committed version of the file back out of the version store.
        let entry =
            repositories::entries::get_commit_entry(&repo, &committed, &csv_path)?.unwrap();
        let store = repo.version_store();
        let version_path = store.get_version_path(&entry.hash).await?;
        let ext = entry.path.extension().unwrap().to_str().unwrap();
        let committed_df =
            df::tabular::read_df_with_extension(version_path, ext, &DFOpts::empty()).await?;
        println!("{committed_df}");

        // The appended row appears after the six original rows.
        let expected = r"shape: (7, 6)
┌─────────────────┬───────┬───────┬───────┬───────┬────────┐
│ file ┆ label ┆ min_x ┆ min_y ┆ width ┆ height │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ f64 ┆ f64 ┆ i64 ┆ i64 │
╞═════════════════╪═══════╪═══════╪═══════╪═══════╪════════╡
│ train/dog_1.jpg ┆ dog ┆ 101.5 ┆ 32.0 ┆ 385 ┆ 330 │
│ train/dog_1.jpg ┆ dog ┆ 102.5 ┆ 31.0 ┆ 386 ┆ 330 │
│ train/dog_2.jpg ┆ dog ┆ 7.0 ┆ 29.5 ┆ 246 ┆ 247 │
│ train/dog_3.jpg ┆ dog ┆ 19.0 ┆ 63.5 ┆ 376 ┆ 421 │
│ train/cat_1.jpg ┆ cat ┆ 57.0 ┆ 35.5 ┆ 304 ┆ 427 │
│ train/cat_2.jpg ┆ cat ┆ 30.5 ┆ 44.0 ┆ 333 ┆ 396 │
│ images/test.jpg ┆ dog ┆ 2.0 ┆ 3.0 ┆ 100 ┆ 120 │
└─────────────────┴───────┴───────┴───────┴───────┴────────┘";
        assert_eq!(format!("{committed_df}"), expected);
        Ok(())
    })
    .await
}
#[test]
fn test_add_exclude_to_simple_select() -> Result<(), OxenError> {
    // A bare `SELECT *` gains an EXCLUDE list for the four oxen bookkeeping columns.
    let rewritten = add_exclude_to_sql("SELECT * FROM table")?;
    let expected = r#"SELECT * EXCLUDE ("_oxen_id", "_oxen_diff_status", "_oxen_row_id", "_oxen_diff_hash") FROM table"#;
    assert_eq!(rewritten, expected);
    Ok(())
}
#[test]
fn test_add_exclude_to_complex_select() -> Result<(), OxenError> {
    // With an explicit column list, the EXCLUDE clause is inserted just before FROM
    // and the WHERE clause is left untouched.
    // NOTE(review): DuckDB documents EXCLUDE as a modifier of `*`; confirm that emitting
    // it after an explicit column list is the intended contract of add_exclude_to_sql.
    let rewritten = add_exclude_to_sql("SELECT col1, col2, col3 FROM table WHERE col1 = 'value'")?;
    let expected = r#"SELECT col1, col2, col3 EXCLUDE ("_oxen_id", "_oxen_diff_status", "_oxen_row_id", "_oxen_diff_hash") FROM table WHERE col1 = 'value'"#;
    assert_eq!(rewritten, expected);
    Ok(())
}
#[test]
fn test_add_exclude_case_insensitive() -> Result<(), OxenError> {
    // Lowercase input is still recognized. The leading keyword comes back as `SELECT`,
    // while the rest of the query (including `from`) keeps its original casing.
    let rewritten = add_exclude_to_sql("select * from table")?;
    let expected = r#"SELECT * EXCLUDE ("_oxen_id", "_oxen_diff_status", "_oxen_row_id", "_oxen_diff_hash") from table"#;
    assert_eq!(rewritten, expected);
    Ok(())
}
#[test]
fn test_invalid_sql() {
    // A DELETE statement is rejected rather than rewritten.
    assert!(add_exclude_to_sql("DELETE FROM table").is_err());
}
#[test]
fn test_add_exclude_to_aggregation_query() -> Result<(), OxenError> {
    // An aggregating query is passed through unchanged.
    let query = "SELECT label, COUNT(*) FROM table GROUP BY label";
    assert_eq!(add_exclude_to_sql(query)?, query);
    Ok(())
}
/// Adds a column named `column_name` with type `data_type` to the indexed data frame
/// at `file_path`, then inserts one row whose new column holds `list_value`.
///
/// Returns the new column as read back from the workspace for the inserted row
/// (looked up by its oxen id).
async fn add_row_with_list_column(
    repo: &crate::model::LocalRepository,
    workspace: &crate::model::Workspace,
    file_path: &Path,
    column_name: &str,
    data_type: &str,
    list_value: serde_json::Value,
) -> Result<polars::prelude::Series, OxenError> {
    use crate::view::data_frames::columns::NewColumn;

    workspaces::data_frames::columns::add(
        repo,
        workspace,
        file_path,
        &NewColumn {
            name: column_name.to_string(),
            data_type: data_type.to_string(),
        },
    )?;

    // Fixed values for file/label/min_x/min_y/width/height, plus the new list column
    // (json! treats the bare `column_name` key as the variable, not a literal name).
    let payload = json!({
        "file": "list_row.jpg",
        "label": "dog",
        "min_x": 1.0,
        "min_y": 2.0,
        "width": 3,
        "height": 4,
        column_name: list_value,
    });
    let inserted = workspaces::data_frames::rows::add(repo, workspace, file_path, &payload)?;

    // The id cell's string rendering may be wrapped in quotes; strip them before lookup.
    let rendered_id = inserted.column(OXEN_ID_COL)?.get(0)?.to_string();
    let oxen_id = rendered_id.trim_matches('"').to_string();

    let fetched = workspaces::data_frames::rows::get_by_id(workspace, file_path, &oxen_id)?;
    let series = fetched.column(column_name)?.as_materialized_series().clone();
    Ok(series)
}
/// Relative path (`annotations/train/bounding_box.csv`) of the CSV used by the
/// list-typed add-row tests.
fn list_typed_add_row_test_paths() -> std::path::PathBuf {
    ["annotations", "train", "bounding_box.csv"].iter().collect()
}
#[tokio::test]
async fn test_add_row_with_list_i64_column() -> Result<(), OxenError> {
    // This test does not run on Windows.
    if cfg!(target_os = "windows") {
        return Ok(());
    }
    test::run_training_data_repo_test_fully_committed_async(|repo| async move {
        let branch = repositories::branches::create_checkout(&repo, "test-list-i64")?;
        let base = repositories::commits::get_by_id(&repo, &branch.commit_id)?.unwrap();
        let workspace =
            repositories::workspaces::create(&repo, &base, UserConfig::identifier()?, true)?;
        let csv_path = list_typed_add_row_test_paths();
        workspaces::data_frames::index(&repo, &workspace, &csv_path).await?;

        // Add a list[i64] column and read back the list stored in the new row.
        let column = add_row_with_list_column(
            &repo,
            &workspace,
            &csv_path,
            "scores",
            "list[i64]",
            json!([10, 20, 30]),
        )
        .await?;
        let cell = column.list()?.get_as_series(0).unwrap();
        let scores = cell
            .i64()?
            .into_iter()
            .map(Option::unwrap)
            .collect::<Vec<i64>>();
        assert_eq!(scores, [10, 20, 30]);
        Ok(())
    })
    .await
}
#[tokio::test]
async fn test_add_row_with_list_str_column_preserves_nulls() -> Result<(), OxenError> {
    // This test does not run on Windows.
    if cfg!(target_os = "windows") {
        return Ok(());
    }
    test::run_training_data_repo_test_fully_committed_async(|repo| async move {
        let branch = repositories::branches::create_checkout(&repo, "test-list-str")?;
        let base = repositories::commits::get_by_id(&repo, &branch.commit_id)?.unwrap();
        let workspace =
            repositories::workspaces::create(&repo, &base, UserConfig::identifier()?, true)?;
        let csv_path = list_typed_add_row_test_paths();
        workspaces::data_frames::index(&repo, &workspace, &csv_path).await?;

        // A null element inside the JSON list must survive as a null list element.
        let column = add_row_with_list_column(
            &repo,
            &workspace,
            &csv_path,
            "tags",
            "list[str]",
            json!(["a", null, "b"]),
        )
        .await?;
        let cell = column.list()?.get_as_series(0).unwrap();
        let tags: Vec<Option<&str>> = cell.str()?.into_iter().collect();
        assert_eq!(tags, [Some("a"), None, Some("b")]);
        Ok(())
    })
    .await
}
#[tokio::test]
async fn test_add_row_with_list_u32_column() -> Result<(), OxenError> {
    // This test does not run on Windows.
    if cfg!(target_os = "windows") {
        return Ok(());
    }
    test::run_training_data_repo_test_fully_committed_async(|repo| async move {
        let branch = repositories::branches::create_checkout(&repo, "test-list-u32")?;
        let base = repositories::commits::get_by_id(&repo, &branch.commit_id)?.unwrap();
        let workspace =
            repositories::workspaces::create(&repo, &base, UserConfig::identifier()?, true)?;
        let csv_path = list_typed_add_row_test_paths();
        workspaces::data_frames::index(&repo, &workspace, &csv_path).await?;

        // Add a list[u32] column and read back the list stored in the new row.
        let column = add_row_with_list_column(
            &repo,
            &workspace,
            &csv_path,
            "ids",
            "list[u32]",
            json!([1u32, 2u32, 3u32]),
        )
        .await?;
        let cell = column.list()?.get_as_series(0).unwrap();
        let ids = cell
            .u32()?
            .into_iter()
            .map(Option::unwrap)
            .collect::<Vec<u32>>();
        assert_eq!(ids, [1, 2, 3]);
        Ok(())
    })
    .await
}
#[tokio::test]
async fn test_add_row_with_list_f64_column() -> Result<(), OxenError> {
    // This test does not run on Windows.
    if cfg!(target_os = "windows") {
        return Ok(());
    }
    test::run_training_data_repo_test_fully_committed_async(|repo| async move {
        let branch = repositories::branches::create_checkout(&repo, "test-list-f64")?;
        let base = repositories::commits::get_by_id(&repo, &branch.commit_id)?.unwrap();
        let workspace =
            repositories::workspaces::create(&repo, &base, UserConfig::identifier()?, true)?;
        let csv_path = list_typed_add_row_test_paths();
        workspaces::data_frames::index(&repo, &workspace, &csv_path).await?;

        // Add a list[f64] column and read back the list stored in the new row.
        let column = add_row_with_list_column(
            &repo,
            &workspace,
            &csv_path,
            "embedding",
            "list[f64]",
            json!([0.1, 0.2, 0.3]),
        )
        .await?;
        let cell = column.list()?.get_as_series(0).unwrap();
        let embedding = cell
            .f64()?
            .into_iter()
            .map(Option::unwrap)
            .collect::<Vec<f64>>();
        assert_eq!(embedding, [0.1, 0.2, 0.3]);
        Ok(())
    })
    .await
}
}