use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec};
use futures_core::Stream;
/// Owned byte buffer used for image payloads.
///
/// It is the item type of the streams returned by [`ImageGenerator`] and the
/// element type of the images attached to a [`Prompt`].
pub type Data = Vec<u8>;
pub trait ImageGenerator {
type Error: core::error::Error + Send + Sync + 'static;
fn create(
&self,
prompt: Prompt,
size: Size,
) -> impl Stream<Item = Result<Data, Self::Error>> + Unpin + Send;
fn edit(
&self,
prompt: Prompt,
mask: &[u8],
) -> impl Stream<Item = Result<Data, Self::Error>> + Unpin + Send;
}
/// Implements [`ImageGenerator`] for owning smart pointers by delegating to the
/// pointee.
///
/// For every listed pointer type `P`, this generates
/// `impl<T: ImageGenerator + ?Sized> ImageGenerator for P<T>`. The impl reuses
/// `T::Error` and forwards `create`/`edit` unchanged, so a `Box<G>` or `Arc<G>`
/// can be used wherever a generator is expected.
macro_rules! impl_image_generator {
    ($($name:ident),* $(,)?) => {
        $(
            // `?Sized` follows the std convention for pointer-forwarding impls
            // (e.g. `impl<I: Iterator + ?Sized> Iterator for Box<I>`), so unsized
            // implementors can also be used behind the pointer.
            impl<T: ImageGenerator + ?Sized> ImageGenerator for $name<T> {
                type Error = T::Error;

                fn create(
                    &self,
                    prompt: Prompt,
                    size: Size,
                ) -> impl Stream<Item = Result<Data, Self::Error>> + Unpin + Send {
                    // Deref explicitly to `&T` rather than relying on implicit
                    // `&$name<T>` -> `&T` coercion at the call site.
                    T::create(&**self, prompt, size)
                }

                fn edit(
                    &self,
                    prompt: Prompt,
                    mask: &[u8],
                ) -> impl Stream<Item = Result<Data, Self::Error>> + Unpin + Send {
                    T::edit(&**self, prompt, mask)
                }
            }
        )*
    };
}

// Let `Arc<G>` and `Box<G>` act as generators themselves.
impl_image_generator!(Arc, Box);
/// Input for an [`ImageGenerator`]: a text instruction plus zero or more
/// attached images.
///
/// Build one with [`Prompt::new`] (or via `From<&str>` / `From<String>`) and
/// attach images with [`Prompt::with_image`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Prompt {
    // Field order matters: the derived `PartialOrd`/`Ord` compare `text` first.
    text: String,
    // Attached images in insertion order. The field keeps the name `image`
    // because serde's derive uses it as the serialized key when the `serde`
    // feature is enabled.
    image: Vec<Data>,
}

impl Prompt {
    /// Creates a prompt holding `text` and no attached images.
    #[must_use]
    pub fn new(text: impl Into<String>) -> Self {
        Self {
            text: text.into(),
            ..Self::default()
        }
    }

    /// Returns the prompt text.
    #[must_use]
    pub fn text(&self) -> &str {
        self.text.as_str()
    }

    /// Returns the attached images, oldest first.
    #[must_use]
    pub fn images(&self) -> &[Data] {
        self.image.as_slice()
    }

    /// Appends `image` to the attached images and returns the prompt, so
    /// calls can be chained.
    #[must_use]
    pub fn with_image(self, image: Data) -> Self {
        let mut prompt = self;
        prompt.image.push(image);
        prompt
    }
}

impl From<String> for Prompt {
    /// Wraps an owned string as a prompt with no images.
    fn from(text: String) -> Self {
        Prompt::new(text)
    }
}

impl From<&str> for Prompt {
    /// Copies a string slice into a prompt with no images.
    fn from(text: &str) -> Self {
        Prompt::new(text)
    }
}
/// Requested image dimensions, `width` by `height` pixels.
///
/// All methods are `const fn`, so sizes can be built in constant contexts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Size {
    // Field order matters: the derived `PartialOrd`/`Ord` compare `width` first.
    width: u32,
    height: u32,
}

impl Size {
    /// Creates a `width` × `height` size.
    #[must_use]
    pub const fn new(width: u32, height: u32) -> Self {
        Self { width, height }
    }

    /// Creates a size whose width and height are both `size`.
    #[must_use]
    pub const fn square(size: u32) -> Self {
        Self::new(size, size)
    }

    /// Returns the width.
    #[must_use]
    pub const fn width(&self) -> u32 {
        self.width
    }

    /// Returns the height.
    #[must_use]
    pub const fn height(&self) -> u32 {
        self.height
    }

    /// Returns `width * height` without overflow.
    #[must_use]
    pub const fn pixel_count(&self) -> u64 {
        // Widen both factors first: the product of two u32 values always fits
        // in u64. `u64::from` is not callable in a `const fn`, and the `as`
        // cast here is a lossless widening.
        let (w, h) = (self.width as u64, self.height as u64);
        w * h
    }

    /// Returns `true` when width and height are equal.
    #[must_use]
    pub const fn is_square(&self) -> bool {
        let Self { width, height } = *self;
        width == height
    }
}
#[cfg(test)]
mod tests {
    use core::convert::Infallible;

    use super::*;
    use alloc::vec;
    use futures_lite::StreamExt;

    /// Fixed four-byte chunk emitted second by the mock. The bytes resemble a
    /// JPEG SOI/APP0 header, but they are only used as an opaque payload.
    const JPEG_HEADER: [u8; 4] = [0xFF, 0xD8, 0xFF, 0xE0];

    /// Length of the zero-filled chunk emitted last by the mock.
    const PADDING_LEN: usize = 100;

    /// Deterministic generator. It ignores `size` and `mask` and always emits
    /// three chunks: the prompt text bytes, `JPEG_HEADER`, then `PADDING_LEN` zeros.
    struct MockImageGenerator;

    impl MockImageGenerator {
        /// Builds the three-chunk stream shared by `create` and `edit`.
        ///
        /// Takes `prompt` by value so the returned opaque type captures no borrow.
        fn chunks(prompt: Prompt) -> impl Stream<Item = Result<Data, Infallible>> + Unpin + Send {
            let chunks = vec![
                prompt.text.into_bytes(),
                JPEG_HEADER.to_vec(),
                vec![0x00; PADDING_LEN],
            ];
            futures_lite::stream::iter(chunks.into_iter().map(Ok))
        }
    }

    impl ImageGenerator for MockImageGenerator {
        type Error = Infallible;

        fn create(
            &self,
            prompt: Prompt,
            _size: Size,
        ) -> impl Stream<Item = Result<Data, Self::Error>> + Unpin + Send {
            Self::chunks(prompt)
        }

        fn edit(
            &self,
            prompt: Prompt,
            _mask: &[u8],
        ) -> impl Stream<Item = Result<Data, Self::Error>> + Unpin + Send {
            Self::chunks(prompt)
        }
    }

    /// Drains `stream` into a vector, unwrapping each item (the mock is infallible).
    async fn collect_chunks<S>(mut stream: S) -> Vec<Data>
    where
        S: Stream<Item = Result<Data, Infallible>> + Unpin,
    {
        let mut chunks = Vec::new();
        while let Some(chunk) = stream.next().await {
            chunks.push(chunk.unwrap());
        }
        chunks
    }

    #[tokio::test]
    async fn image_generation() {
        let generator = MockImageGenerator;
        let chunks = collect_chunks(generator.create(Prompt::new("a cat"), Size::square(256))).await;
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], b"a cat".to_vec());
        assert_eq!(chunks[1], vec![0xFF, 0xD8, 0xFF, 0xE0]);
        assert_eq!(chunks[2], vec![0x00; 100]);
    }

    #[tokio::test]
    async fn image_generation_empty_prompt() {
        let generator = MockImageGenerator;
        let chunks = collect_chunks(generator.create(Prompt::new(""), Size::square(256))).await;
        assert_eq!(chunks.len(), 3);
        assert!(chunks[0].is_empty());
        assert_eq!(chunks[1], JPEG_HEADER.to_vec());
        assert_eq!(chunks[2], vec![0x00; PADDING_LEN]);
    }

    #[tokio::test]
    async fn image_generation_long_prompt() {
        let generator = MockImageGenerator;
        let long_prompt = "a very detailed and elaborate description of a beautiful landscape with mountains, rivers, and forests";
        let chunks = collect_chunks(generator.create(Prompt::new(long_prompt), Size::square(512))).await;
        let total_bytes: usize = chunks.iter().map(Vec::len).sum();
        assert_eq!(total_bytes, long_prompt.len() + JPEG_HEADER.len() + PADDING_LEN);
    }

    #[tokio::test]
    async fn image_edit() {
        let generator = MockImageGenerator;
        let mask = [0xAA_u8; 16];
        let edited = collect_chunks(generator.edit(Prompt::new("a dog"), &mask)).await;
        let created = collect_chunks(generator.create(Prompt::new("a dog"), Size::square(64))).await;
        assert_eq!(edited.len(), 3);
        assert_eq!(edited[0], b"a dog".to_vec());
        assert_eq!(edited, created);
    }

    #[tokio::test]
    async fn smart_pointer_forwarding() {
        let boxed = Box::new(MockImageGenerator);
        let shared = Arc::new(MockImageGenerator);
        // Call through the trait explicitly so the `Box`/`Arc` impls are exercised.
        let from_box =
            collect_chunks(ImageGenerator::create(&boxed, Prompt::new("box"), Size::square(8))).await;
        let from_arc =
            collect_chunks(ImageGenerator::edit(&shared, Prompt::new("arc"), &[0x01])).await;
        assert_eq!(from_box[0], b"box".to_vec());
        assert_eq!(from_arc[0], b"arc".to_vec());
        assert_eq!(from_box.len(), 3);
        assert_eq!(from_arc.len(), 3);
    }

    #[test]
    fn prompt_builder() {
        let prompt = Prompt::from("edit me").with_image(vec![1, 2]).with_image(vec![3]);
        assert_eq!(prompt.text(), "edit me");
        assert_eq!(prompt.images().len(), 2);
        assert_eq!(prompt.images()[0], vec![1, 2]);
        assert_eq!(prompt.images()[1], vec![3]);

        let owned = Prompt::from(String::from("owned"));
        assert_eq!(owned.text(), "owned");
        assert!(owned.images().is_empty());

        let empty = Prompt::default();
        assert_eq!(empty.text(), "");
        assert!(empty.images().is_empty());
    }

    #[test]
    fn size_accessors() {
        let size = Size::new(640, 480);
        assert_eq!(size.width(), 640);
        assert_eq!(size.height(), 480);
        assert_eq!(size.pixel_count(), 307_200);
        assert!(!size.is_square());
        assert!(Size::square(256).is_square());
        assert_eq!(
            Size::new(u32::MAX, u32::MAX).pixel_count(),
            u64::from(u32::MAX) * u64::from(u32::MAX)
        );
    }

    // Purely synchronous: no runtime needed.
    #[test]
    fn data_type_alias() {
        let data: Data = vec![1, 2, 3, 4];
        assert_eq!(data.len(), 4);
        assert_eq!(data[0], 1);
        assert_eq!(data[3], 4);
    }

    #[test]
    fn data_operations() {
        let mut data: Data = vec![0xFF; 1024];
        assert_eq!(data.len(), 1024);
        data.push(0x00);
        assert_eq!(data.len(), 1025);
        assert_eq!(data[1024], 0x00);
        data.extend_from_slice(&[0x01, 0x02]);
        assert_eq!(data.len(), 1027);
        assert_eq!(data[1025], 0x01);
        assert_eq!(data[1026], 0x02);
    }
}