#![warn(missing_docs)]
use crate::nn::transform::Transform;
use crate::{Tensor, TensorElement};
use std::error::Error;
use std::str::FromStr;
/// A collection of samples that can be counted and fetched by index.
///
/// An implementor is built from a list of label strings, a directory path
/// and two optional [`Transform`]s (see [`DataSet::new`]); it reports its
/// length through [`DataSet::size`] and yields individual samples through
/// [`DataSet::get`].
pub trait DataSet {
    /// Creates a dataset from its parts.
    ///
    /// # Arguments
    ///
    /// * `img_labels` - label strings for the samples
    ///   (NOTE(review): presumably one entry per sample — confirm).
    /// * `img_dir` - a `'static` path string
    ///   (NOTE(review): presumably the directory the images are read from — confirm).
    /// * `transform` - optional [`Transform`]
    ///   (NOTE(review): presumably applied to each sample tensor — confirm).
    /// * `target_transform` - optional [`Transform`]
    ///   (NOTE(review): presumably applied to each label — confirm).
    fn new(
        img_labels: Vec<String>,
        img_dir: &'static str,
        transform: Option<Transform>,
        target_transform: Option<Transform>,
    ) -> Self;

    /// Returns the number of samples in the dataset.
    fn size(&self) -> usize;

    /// Returns the sample at position `idx` as a [`Tensor`] together with a
    /// `String` (NOTE(review): presumably the sample's label — confirm).
    ///
    /// The element type `T` must be a [`TensorElement`] whose [`FromStr`]
    /// error type implements [`Error`]. The lifetime `'a` of the returned
    /// tensor is chosen by the caller and is not tied to `&self`.
    fn get<'a, T>(&self, idx: usize) -> (Tensor<'a, T>, String)
    where
        T: TensorElement,
        <T as FromStr>::Err: Error;
}
/// A [`DataSet`] made of a list of label strings, a directory path and two
/// optional [`Transform`]s.
///
/// Construct it with [`DataSet::new`]. Note that its `DataSet::get`
/// implementation is still a `todo!()` stub and panics when called.
pub struct CustomImageDataset {
    // Label strings; their count is the dataset size.
    img_labels: Vec<String>,
    // NOTE(review): presumably the directory the image files live in — confirm.
    img_dir: &'static str,
    // NOTE(review): presumably applied to each sample tensor — confirm.
    transform: Option<Transform>,
    // NOTE(review): presumably applied to each label — confirm.
    target_transform: Option<Transform>,
}
impl DataSet for CustomImageDataset {
fn new(
img_labels: Vec<String>,
img_dir: &'static str,
transform: Option<Transform>,
target_transform: Option<Transform>,
) -> Self {
Self {
img_labels,
img_dir,
transform,
target_transform,
}
}
fn size(&self) -> usize {
self.img_labels.len()
}
fn get<'a, T>(&self, idx: usize) -> (Tensor<'a, T>, String)
where
T: TensorElement,
<T as FromStr>::Err: Error,
{
todo!()
}
}