use crate::wasm_compat::*;
#[allow(unused_imports)] use futures::join;
use futures::stream;
use std::future::Future;
pub trait Op: WasmCompatSend + WasmCompatSync {
type Input: WasmCompatSend + WasmCompatSync;
type Output: WasmCompatSend + WasmCompatSync;
fn call(&self, input: Self::Input) -> impl Future<Output = Self::Output> + WasmCompatSend;
fn batch_call<I>(
&self,
n: usize,
input: I,
) -> impl Future<Output = Vec<Self::Output>> + WasmCompatSend
where
I: IntoIterator<Item = Self::Input> + WasmCompatSend,
I::IntoIter: WasmCompatSend,
Self: Sized,
{
use futures::stream::StreamExt;
async move {
stream::iter(input)
.map(|input| self.call(input))
.buffered(n)
.collect()
.await
}
}
fn map<F, Input>(self, f: F) -> Sequential<Self, Map<F, Self::Output>>
where
F: Fn(Self::Output) -> Input + WasmCompatSend + WasmCompatSync,
Input: WasmCompatSend + WasmCompatSync,
Self: Sized,
{
Sequential::new(self, Map::new(f))
}
fn then<F, Fut>(self, f: F) -> Sequential<Self, Then<F, Fut::Output>>
where
F: Fn(Self::Output) -> Fut + Send + WasmCompatSync,
Fut: Future + WasmCompatSend + WasmCompatSync,
Fut::Output: WasmCompatSend + WasmCompatSync,
Self: Sized,
{
Sequential::new(self, Then::new(f))
}
fn chain<T>(self, op: T) -> Sequential<Self, T>
where
T: Op<Input = Self::Output>,
Self: Sized,
{
Sequential::new(self, op)
}
fn lookup<I, Input>(
self,
index: I,
n: usize,
) -> Sequential<Self, Lookup<I, Self::Output, Input>>
where
I: vector_store::VectorStoreIndex,
Input: WasmCompatSend + WasmCompatSync + for<'a> serde::Deserialize<'a>,
Self::Output: Into<String>,
Self: Sized,
{
Sequential::new(self, Lookup::new(index, n))
}
fn prompt<P>(self, prompt: P) -> Sequential<Self, Prompt<P, Self::Output>>
where
P: completion::Prompt,
Self::Output: Into<String>,
Self: Sized,
{
Sequential::new(self, Prompt::new(prompt))
}
}
/// Lets a shared reference to an op act as an op itself, so a stage can be borrowed
/// into a pipeline instead of moved.
impl<T: Op> Op for &T {
    type Input = T::Input;
    type Output = T::Output;

    /// Delegates to the referenced op's [`Op::call`].
    #[inline]
    async fn call(&self, input: Self::Input) -> Self::Output {
        T::call(*self, input).await
    }
}
/// Two ops run back to back: the output of `prev` becomes the input of `op`.
pub struct Sequential<Op1, Op2> {
    prev: Op1,
    op: Op2,
}

impl<Op1, Op2> Sequential<Op1, Op2> {
    /// Builds a pipeline that runs `prev` first and then `op`.
    pub(crate) fn new(prev: Op1, op: Op2) -> Self {
        Sequential { prev, op }
    }
}

impl<Op1, Op2> Op for Sequential<Op1, Op2>
where
    Op1: Op,
    Op2: Op<Input = Op1::Output>,
{
    type Input = Op1::Input;
    type Output = Op2::Output;

    /// Awaits `prev` on `input`, then awaits `op` on that intermediate result.
    #[inline]
    async fn call(&self, input: Self::Input) -> Self::Output {
        let Sequential { prev, op } = self;
        let intermediate = prev.call(input).await;
        op.call(intermediate).await
    }
}
use super::agent_ops::{Lookup, Prompt};
use crate::{completion, vector_store};
/// An [`Op`] that applies the synchronous function `f` to its input.
///
/// `Input` is tracked only at the type level (through `PhantomData`) so the `Op`
/// impl can name the argument type of `f`.
pub struct Map<F, Input> {
    f: F,
    _t: std::marker::PhantomData<Input>,
}

impl<F, Input> Map<F, Input> {
    /// Wraps `f` as a [`Map`] op.
    pub(crate) fn new(f: F) -> Self {
        Map {
            f,
            _t: Default::default(),
        }
    }
}

impl<F, Input, Output> Op for Map<F, Input>
where
    F: Fn(Input) -> Output + WasmCompatSend + WasmCompatSync,
    Input: WasmCompatSend + WasmCompatSync,
    Output: WasmCompatSend + WasmCompatSync,
{
    type Input = Input;
    type Output = Output;

    /// Returns `f(input)`; nothing is awaited.
    #[inline]
    async fn call(&self, input: Self::Input) -> Self::Output {
        let Self { f, .. } = self;
        f(input)
    }
}
/// Creates a standalone [`Map`] op from the synchronous function `f`.
pub fn map<F, Input, Output>(f: F) -> Map<F, Input>
where
    F: Fn(Input) -> Output + WasmCompatSend + WasmCompatSync,
    Input: WasmCompatSend + WasmCompatSync,
    Output: WasmCompatSend + WasmCompatSync,
{
    Map::<F, Input>::new(f)
}
/// An [`Op`] whose output is exactly its input.
pub struct Passthrough<T> {
    _t: std::marker::PhantomData<T>,
}

impl<T> Passthrough<T> {
    /// Creates an identity op for values of type `T`.
    pub(crate) fn new() -> Self {
        Passthrough {
            _t: Default::default(),
        }
    }
}

impl<T> Op for Passthrough<T>
where
    T: WasmCompatSend + WasmCompatSync,
{
    type Input = T;
    type Output = T;

    /// Hands `input` back unchanged.
    async fn call(&self, input: Self::Input) -> Self::Output {
        std::convert::identity(input)
    }
}
/// Creates a standalone [`Passthrough`] (identity) op for values of type `T`.
pub fn passthrough<T>() -> Passthrough<T>
where
    T: WasmCompatSend + WasmCompatSync,
{
    Passthrough::<T>::new()
}
/// An [`Op`] that calls `f` on its input and awaits the future `f` returns.
///
/// `Input` is tracked only at the type level (through `PhantomData`) so the `Op`
/// impl can name the argument type of `f`.
pub struct Then<F, Input> {
    f: F,
    _t: std::marker::PhantomData<Input>,
}

impl<F, Input> Then<F, Input> {
    /// Wraps the async-returning function `f` as a [`Then`] op.
    pub(crate) fn new(f: F) -> Self {
        Then {
            f,
            _t: Default::default(),
        }
    }
}

impl<F, Input, Fut> Op for Then<F, Input>
where
    F: Fn(Input) -> Fut + WasmCompatSend + WasmCompatSync,
    Input: WasmCompatSend + WasmCompatSync,
    Fut: Future + WasmCompatSend,
    Fut::Output: WasmCompatSend + WasmCompatSync,
{
    type Input = Input;
    type Output = Fut::Output;

    /// Builds the future via `f(input)` and resolves to its output.
    #[inline]
    async fn call(&self, input: Self::Input) -> Self::Output {
        let pending = (self.f)(input);
        pending.await
    }
}
/// Creates a standalone [`Then`] op from `f`, which maps an input to a future.
pub fn then<F, Input, Fut>(f: F) -> Then<F, Input>
where
    F: Fn(Input) -> Fut + WasmCompatSend + WasmCompatSync,
    Input: WasmCompatSend + WasmCompatSync,
    Fut: Future + WasmCompatSend,
    Fut::Output: WasmCompatSend + WasmCompatSync,
{
    Then::<F, Input>::new(f)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Nesting `Sequential::new` by hand runs the stages in order: ((1 + 1) * 2) * 3.
    #[tokio::test]
    async fn test_sequential_constructor() {
        let add_one = map(|x: i32| x + 1);
        let double = map(|x: i32| x * 2);
        let triple = map(|x: i32| x * 3);

        let first_two = Sequential::new(add_one, double);
        let pipeline = Sequential::new(first_two, triple);

        assert_eq!(pipeline.call(1).await, 12);
    }

    /// The builder methods produce the same pipeline as the hand-built version.
    #[tokio::test]
    async fn test_sequential_chain() {
        let pipeline = map(|x: i32| x + 1)
            .map(|x| x * 2)
            .then(|x| async move { x * 3 });

        assert_eq!(pipeline.call(1).await, 12);
    }
}