// chorus_lib 0.5.0
//
// Choreographic Programming in Rust
// Documentation
use std::marker::PhantomData;
use std::thread::{self, JoinHandle};

use chorus_lib::core::{
    ChoreoOp, Choreography, ChoreographyLocation, Located, LocationSet, Projector,
};
use chorus_lib::transport::local::{LocalTransport, LocalTransportChannelBuilder};

/// Location marker for Alice. In this file `Main` places the unsorted list
/// at Alice and prints the sorted result there.
#[derive(ChoreographyLocation)]
pub struct Alice;

/// Location marker for Bob. `Main` passes it as the second location of
/// `Sort`, i.e. the one that receives the first half of the list.
#[derive(ChoreographyLocation)]
pub struct Bob;

/// Location marker for Carol. `Main` passes it as the third location of
/// `Sort`, i.e. the one that receives the second half of the list.
#[derive(ChoreographyLocation)]
pub struct Carol;

/// Entry-point choreography: sorts a fixed list of ten integers located at
/// `Alice` and prints the sorted list at `Alice`.
pub struct Main;

impl Choreography for Main {
    type L = LocationSet!(Alice, Bob, Carol);

    /// Creates the input list, runs [`Sort`] over `Alice`, `Bob` and `Carol`,
    /// and prints the result with `{:?}` formatting at `Alice`.
    fn run(self, op: &impl ChoreoOp<Self::L>) {
        let unsorted_at_alice = Located::local(vec![9, 1, 4, 7, 5, 2, 3, 0, 6, 8]);
        let sort = Sort::<Alice, Bob, Carol>::new(unsorted_at_alice);
        let result_at_alice = op.call(sort);

        op.locally(Alice, |un| {
            let sorted_list = un.unwrap(&result_at_alice);
            println!("{sorted_list:?}");
        });
    }
}

/// Choreography that sorts a list located at `A`, with `B` and `C` as the
/// two other participating locations.
pub struct Sort<A, B, C>
where
    A: ChoreographyLocation,
    B: ChoreographyLocation,
    C: ChoreographyLocation,
{
    /// The list to sort, located at `A`.
    list_at_a: Located<Vec<i64>, A>,
    /// Carries the `B` and `C` type parameters, which no field stores.
    phantom: PhantomData<(B, C)>,
}

impl<A, B, C> Sort<A, B, C>
where
    A: ChoreographyLocation,
    B: ChoreographyLocation,
    C: ChoreographyLocation,
{
    /// Creates a sorting choreography for `list`, a list located at `A`.
    ///
    /// `B` and `C` are selected purely through the type parameters, e.g.
    /// `Sort::<Alice, Bob, Carol>::new(list)`.
    pub const fn new(list: Located<Vec<i64>, A>) -> Self {
        Self {
            list_at_a: list,
            phantom: PhantomData,
        }
    }
}

impl<A, B, C> Choreography<Located<Vec<i64>, A>> for Sort<A, B, C>
where
    A: ChoreographyLocation,
    B: ChoreographyLocation,
    C: ChoreographyLocation,
{
    type L = LocationSet!(A, B, C);

    /// Sorts the list located at `A` and returns the sorted list, again at `A`.
    ///
    /// A list with at most one element is returned unchanged. Otherwise the
    /// list is cut at `len / 2`: the front half is sent to `B`, the back half
    /// to `C`, each half is sorted by a nested `Sort` with the location roles
    /// rotated, and the two sorted halves are combined by [`Merge`].
    fn run(self, op: &impl ChoreoOp<Self::L>) -> Located<Vec<i64>, A> {
        let (loc_a, loc_b, loc_c) = (A::new(), B::new(), C::new());

        // `A` checks the length; `broadcast` turns its located flag into a
        // plain `bool` that the branch below can test.
        let is_small_at_a = op.locally(loc_a, |un| un.unwrap(&self.list_at_a).len() <= 1);
        if op.broadcast(loc_a, is_small_at_a) {
            return self.list_at_a;
        }

        let midpoint_at_a = op.locally(loc_a, |un| un.unwrap(&self.list_at_a).len() / 2);

        // Front half: copied at `A`, sent to `B`, sorted with `B` as the owner.
        let front_at_a = op.locally(loc_a, |un| {
            let midpoint = *un.unwrap(&midpoint_at_a);
            let (front, _) = un.unwrap(&self.list_at_a).split_at(midpoint);
            front.to_vec()
        });
        let front_at_b = op.comm(loc_a, loc_b, &front_at_a);
        let sorted_front_at_b = op.call(Sort::<B, C, A>::new(front_at_b));

        // Back half: copied at `A`, sent to `C`, sorted with `C` as the owner.
        let back_at_a = op.locally(loc_a, |un| {
            let midpoint = *un.unwrap(&midpoint_at_a);
            let (_, back) = un.unwrap(&self.list_at_a).split_at(midpoint);
            back.to_vec()
        });
        let back_at_c = op.comm(loc_a, loc_c, &back_at_a);
        let sorted_back_at_c = op.call(Sort::<C, A, B>::new(back_at_c));

        op.call(Merge::new(sorted_front_at_b, sorted_back_at_c))
    }
}

/// Choreography that merges a list located at `B` with a list located at `C`.
///
/// `Sort::run` constructs it from the two sorted halves of its input.
pub struct Merge<B, C>
where
    B: ChoreographyLocation,
    C: ChoreographyLocation,
{
    /// First input list, located at `B`.
    prefix_at_b: Located<Vec<i64>, B>,
    /// Second input list, located at `C`.
    suffix_at_c: Located<Vec<i64>, C>,
}

impl<B, C> Merge<B, C>
where
    B: ChoreographyLocation,
    C: ChoreographyLocation,
{
    pub const fn new(prefix_at_b: Located<Vec<i64>, B>, suffix_at_c: Located<Vec<i64>, C>) -> Self {
        Self {
            prefix_at_b,
            suffix_at_c,
        }
    }
}

impl<A, B, C> Choreography<Located<Vec<i64>, A>> for Merge<B, C>
where
    A: ChoreographyLocation,
    B: ChoreographyLocation,
    C: ChoreographyLocation,
{
    type L = LocationSet!(A, B, C);

    /// Merges the list at `B` with the list at `C` and returns the result at `A`.
    ///
    /// Base cases: if the list at `B` is empty, the list at `C` is sent to `A`;
    /// if the list at `C` is empty, the list at `B` is sent to `A`.
    ///
    /// Otherwise `B` sends its first element to `C`, and `C` compares it with
    /// its own first element and broadcasts the outcome. The smaller head (the
    /// one from `B` on a tie) is sent to `A`, the remaining elements are merged
    /// by a nested `Merge`, and `A` prepends the head to that partial result.
    fn run(self, op: &impl ChoreoOp<Self::L>) -> Located<Vec<i64>, A> {
        // The bindings deliberately reuse the type-parameter names so the
        // choreography reads `op.comm(B, C, ..)`; hence the lint exemption.
        #[allow(non_snake_case)]
        let (A, B, C) = (A::new(), B::new(), C::new());

        let is_prefix_empty = op.broadcast(
            B,
            op.locally(B, |un| {
                let prefix = un.unwrap(&self.prefix_at_b);
                prefix.is_empty()
            }),
        );

        if is_prefix_empty {
            // Nothing left at `B`, so the list at `C` is the whole result.
            // `comm` borrows the located value, so it is sent as is, without
            // making a local copy first.
            return op.comm(C, A, &self.suffix_at_c);
        }

        let is_suffix_empty = op.broadcast(
            C,
            op.locally(C, |un| {
                let suffix = un.unwrap(&self.suffix_at_c);
                suffix.is_empty()
            }),
        );

        if is_suffix_empty {
            // Nothing left at `C`, so the list at `B` is the whole result.
            return op.comm(B, A, &self.prefix_at_b);
        }

        // Both lists are non-empty from here on, so indexing `[0]` and
        // slicing `[1..]` below are in bounds at the owning location.
        let head_of_prefix_at_b = op.locally(B, |un| un.unwrap(&self.prefix_at_b)[0]);
        let head_of_prefix_at_c = op.comm(B, C, &head_of_prefix_at_b);
        let head_of_suffix_at_c = op.locally(C, |un| un.unwrap(&self.suffix_at_c)[0]);

        // `C` now holds both heads, so it decides which one comes first.
        let is_head_of_prefix_smaller = op.broadcast(
            C,
            op.locally(C, |un| {
                let head_of_prefix = un.unwrap(&head_of_prefix_at_c);
                let head_of_suffix = un.unwrap(&head_of_suffix_at_c);
                head_of_prefix <= head_of_suffix
            }),
        );

        if is_head_of_prefix_smaller {
            // Take the head from `B`, then merge the rest of `B`'s list with
            // all of `C`'s list.
            let head_of_prefix_at_a = op.comm(B, A, &head_of_prefix_at_b);
            let prefix_at_b = op.locally(B, |un| {
                let prefix = un.unwrap(&self.prefix_at_b);
                prefix[1..].to_vec()
            });
            let partial_at_a = op.call(Merge::new(prefix_at_b, self.suffix_at_c));
            op.locally(A, |un| {
                let partial = un.unwrap(&partial_at_a);
                let head_of_prefix = un.unwrap(&head_of_prefix_at_a);
                [head_of_prefix]
                    .into_iter()
                    .chain(partial)
                    .copied()
                    .collect()
            })
        } else {
            // Take the head from `C`, then merge all of `B`'s list with the
            // rest of `C`'s list.
            let head_of_suffix_at_a = op.comm(C, A, &head_of_suffix_at_c);
            let suffix_at_c = op.locally(C, |un| {
                let suffix = un.unwrap(&self.suffix_at_c);
                suffix[1..].to_vec()
            });
            let partial_at_a = op.call(Merge::new(self.prefix_at_b, suffix_at_c));
            op.locally(A, |un| {
                let partial = un.unwrap(&partial_at_a);
                let head_of_suffix = un.unwrap(&head_of_suffix_at_a);
                [head_of_suffix]
                    .into_iter()
                    .chain(partial)
                    .copied()
                    .collect()
            })
        }
    }
}

/// Runs the `Main` choreography once per location (`Alice`, `Bob`, `Carol`),
/// each projection on its own thread over a shared local transport channel,
/// then joins the threads in spawn order.
fn main() {
    let channel = LocalTransportChannelBuilder::new()
        .with(Alice)
        .with(Bob)
        .with(Carol)
        .build();

    let alice_transport = LocalTransport::new(Alice, channel.clone());
    let alice_thread = thread::spawn(move || {
        let projector = Projector::new(Alice, alice_transport);
        projector.epp_and_run(Main);
    });

    let bob_transport = LocalTransport::new(Bob, channel.clone());
    let bob_thread = thread::spawn(move || {
        let projector = Projector::new(Bob, bob_transport);
        projector.epp_and_run(Main);
    });

    let carol_transport = LocalTransport::new(Carol, channel.clone());
    let carol_thread = thread::spawn(move || {
        let projector = Projector::new(Carol, carol_transport);
        projector.epp_and_run(Main);
    });

    // Join in spawn order; the first thread that panicked makes `unwrap`
    // panic here as well.
    let threads: [JoinHandle<()>; 3] = [alice_thread, bob_thread, carol_thread];
    for handle in threads {
        handle.join().unwrap();
    }
}