// polars-stream 0.53.0
//
// Private crate for the streaming execution engine for the Polars DataFrame library
// Documentation
use std::sync::Arc;

use polars_core::schema::Schema;

use super::compute_node_prelude::*;
use crate::async_primitives::wait_group::WaitGroup;
use crate::morsel::{SourceToken, get_ideal_morsel_size};
use crate::nodes::in_memory_sink::InMemorySinkNode;
/// Compute node with two inputs (a value and a repeat count) and one output.
///
/// The node is a two-phase state machine: it first buffers both inputs, and
/// once both input ports are done it switches to emitting morsels built from
/// the buffered value.
pub enum RepeatNode {
    /// Input phase: both inputs are still being collected into in-memory sinks.
    GatheringParams {
        /// Sink fed from input port 0 (the value to repeat).
        value: InMemorySinkNode,
        /// Sink fed from input port 1 (the repeat count).
        repeats: InMemorySinkNode,
    },
    /// Output phase: the inputs are fully read and morsels are being emitted.
    Repeating {
        /// The buffered value frame; output morsels are built from its row 0.
        value: DataFrame,
        /// Sequence number to attach to the next emitted morsel.
        seq: MorselSeq,
        /// Number of repetitions that still have to be emitted.
        repeats_left: usize,
    },
}

impl RepeatNode {
    /// Creates a node in the [`RepeatNode::GatheringParams`] state, with one
    /// in-memory sink per input.
    ///
    /// # Panics
    ///
    /// Panics if either schema does not have a length of exactly 1.
    pub fn new(value_schema: Arc<Schema>, repeats_schema: Arc<Schema>) -> Self {
        assert!(value_schema.len() == 1);
        assert!(repeats_schema.len() == 1);

        let value = InMemorySinkNode::new(value_schema);
        let repeats = InMemorySinkNode::new(repeats_schema);
        Self::GatheringParams { value, repeats }
    }
}

impl ComputeNode for RepeatNode {
    /// Name under which this node is reported.
    fn name(&self) -> &str {
        "repeat"
    }

    /// Advances the state machine and publishes the resulting port states.
    ///
    /// When both input ports are done while still gathering, the buffered
    /// repeat count and value are taken out of their sinks and the node
    /// switches to the repeating state. Afterwards:
    ///
    /// * while gathering, each sink updates its own input port and the output
    ///   is blocked;
    /// * while repeating, both inputs are marked done and the output is ready
    ///   until no repetitions are left, at which point it is done.
    ///
    /// # Panics
    ///
    /// Panics if the port slices do not have lengths 2 and 1, if a sink yields
    /// no output, or if the repeat count cannot be extracted as a `usize`.
    fn update_state(
        &mut self,
        recv: &mut [PortState],
        send: &mut [PortState],
        state: &StreamingExecutionState,
    ) -> PolarsResult<()> {
        assert!(recv.len() == 2 && send.len() == 1);

        // Both inputs exhausted: the parameters are complete, so start repeating.
        let inputs_done = recv.iter().all(|port| *port == PortState::Done);
        if inputs_done {
            if let Self::GatheringParams { value, repeats } = self {
                // The repeat count is read from the first row of the first column.
                let repeats_df = repeats.get_output()?.unwrap();
                let repeats_left = repeats_df.columns()[0]
                    .get(0)?
                    .extract::<usize>()
                    .unwrap();

                let value = value.get_output()?.unwrap();
                *self = Self::Repeating {
                    value,
                    seq: MorselSeq::default(),
                    repeats_left,
                };
            }
        }

        match self {
            Self::GatheringParams { value, repeats } => {
                // Each sink sees only its own input port and no output ports.
                let (value_recv, repeats_recv) = recv.split_at_mut(1);
                value.update_state(value_recv, &mut [], state)?;
                repeats.update_state(repeats_recv, &mut [], state)?;
                send[0] = PortState::Blocked;
            },
            Self::Repeating { repeats_left, .. } => {
                for port in recv.iter_mut() {
                    *port = PortState::Done;
                }
                send[0] = match *repeats_left {
                    0 => PortState::Done,
                    _ => PortState::Ready,
                };
            },
        }
        Ok(())
    }

    /// Spawns the tasks for the current phase.
    ///
    /// While gathering, spawning is delegated to whichever sinks have an open
    /// input port. While repeating, a single low-priority task sends morsels
    /// over the serial output port until the remaining count reaches zero, a
    /// stop is requested through the source token, or sending fails.
    ///
    /// # Panics
    ///
    /// Panics if the port slices do not have lengths 2 and 1, if an output
    /// port is present while gathering, or if an input port is present (or the
    /// output port is missing) while repeating.
    fn spawn<'env, 's>(
        &'env mut self,
        scope: &'s TaskScope<'s, 'env>,
        recv_ports: &mut [Option<RecvPort<'_>>],
        send_ports: &mut [Option<SendPort<'_>>],
        state: &'s StreamingExecutionState,
        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
    ) {
        assert!(recv_ports.len() == 2 && send_ports.len() == 1);
        match self {
            Self::GatheringParams { value, repeats } => {
                assert!(send_ports[0].is_none());

                let (value_port, repeats_port) = recv_ports.split_at_mut(1);
                if value_port[0].is_some() {
                    value.spawn(scope, value_port, &mut [], state, join_handles);
                }
                if repeats_port[0].is_some() {
                    repeats.spawn(scope, repeats_port, &mut [], state, join_handles);
                }
            },
            Self::Repeating {
                value,
                seq,
                repeats_left,
            } => {
                assert!(recv_ports[0].is_none());
                assert!(recv_ports[1].is_none());

                let mut sender = send_ports[0].take().unwrap().serial();

                // Aim for roughly ideal-sized morsels, rounding the morsel count up
                // to a multiple of the pipeline count.
                let remaining = *repeats_left;
                let ideal_morsel_count = (remaining / get_ideal_morsel_size()).max(1);
                let morsel_count = ideal_morsel_count.next_multiple_of(state.num_pipelines);
                let morsel_size = remaining.div_ceil(morsel_count).max(1);

                let task = scope.spawn_task(TaskPriority::Low, async move {
                    let source_token = SourceToken::new();
                    let wait_group = WaitGroup::default();

                    loop {
                        if *repeats_left == 0 || source_token.stop_requested() {
                            break;
                        }

                        // The final morsel may be shorter than `morsel_size`.
                        let height = (*repeats_left).min(morsel_size);
                        let frame = value.new_from_index(0, height);
                        let mut morsel = Morsel::new(frame, *seq, source_token.clone());
                        morsel.set_consume_token(wait_group.token());

                        // Progress is recorded before sending.
                        *seq = seq.successor();
                        *repeats_left -= height;

                        if sender.send(morsel).await.is_err() {
                            break;
                        }
                        // Wait on the consume token's wait group before building
                        // the next morsel.
                        wait_group.wait().await;
                    }

                    Ok(())
                });
                join_handles.push(task);
            },
        }
    }
}