// ella_server/client/publisher.rs

use std::{fmt::Debug, panic::resume_unwind, task::Poll};

use arrow_flight::{
    encode::FlightDataEncoderBuilder,
    sql::{CommandStatementUpdate, ProstMessageExt},
    FlightData, FlightDescriptor,
};
use datafusion::arrow::record_batch::RecordBatch;
use ella_engine::{registry::TableId, EngineError};
use flume::r#async::SendSink;
use futures::{
    stream::{AbortHandle, Abortable},
    FutureExt, Sink, SinkExt, StreamExt,
};
use prost::Message;
use tokio::task::JoinHandle;

use super::EllaClient;

/// Client-side publisher that forwards [`RecordBatch`]es to a table through a
/// Flight `do_put` request running on a background task.
///
/// Batches are pushed through the [`Sink`] implementation; closing the sink
/// waits for the background task and reports its outcome.
pub struct FlightPublisher {
    /// Sending half (as a sink) of the bounded channel that feeds batches to
    /// the upload stream.
    send: SendSink<'static, RecordBatch>,
    /// Background task driving the `do_put` request.
    handle: JoinHandle<crate::Result<()>>,
    /// Table this publisher writes to; also the only field shown by `Debug`.
    table: TableId<'static>,
    /// Handle used to abort the batch stream handed to `do_put`.
    stop: AbortHandle,
}

impl Debug for FlightPublisher {
    /// Formats the publisher showing only its target table; the remaining
    /// fields are elided and the output is marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("FlightPublisher");
        builder.field("table", &self.table);
        builder.finish_non_exhaustive()
    }
}

impl FlightPublisher {
    /// Creates a publisher that writes record batches to `table` using `client`.
    ///
    /// A task is spawned with `tokio::spawn` that issues a Flight `do_put`
    /// whose payload is a descriptor message followed by the encoded batches
    /// received from the sink side of this publisher.
    ///
    /// # Panics
    ///
    /// `tokio::spawn` panics when called outside of a Tokio runtime.
    pub fn new(mut client: EllaClient, table: TableId<'static>) -> Self {
        // Capacity of one: at most a single batch is queued between the sink
        // and the upload stream.
        let (send, recv) = flume::bounded(1);
        let send = send.into_sink();
        let descriptor = FlightDescriptor::new_cmd(
            CommandStatementUpdate {
                query: format!("insert into {} table this", table),
                transaction_id: None,
            }
            .as_any()
            .encode_to_vec(),
        );
        let (stop, reg) = AbortHandle::new_pair();
        // The first message of the upload carries only the descriptor.
        let header = futures::stream::once(async { FlightData::new().with_descriptor(descriptor) });
        // NOTE(review): encoder errors are unwrapped, so an encoding failure
        // panics wherever this stream is polled (presumably inside the spawned
        // task) instead of surfacing as an error value.
        let stream = FlightDataEncoderBuilder::new()
            .build(recv.into_stream().map(Ok))
            .map(|res| res.unwrap());
        // Wrapped so that `stop` can terminate the batch stream.
        let stream = Abortable::new(stream, reg);

        let handle = tokio::spawn(async move {
            let mut resp = client.flight.do_put(header.chain(stream)).await?;
            resp.message().await.map_err(crate::ClientError::from)?;
            Ok(())
        });
        Self {
            send,
            handle,
            table,
            stop,
        }
    }

    /// Produces the most specific error available after the channel to the
    /// background task has failed.
    ///
    /// The task handle is polled once without blocking:
    /// - still running, or finished successfully: `ClientError::TopicClosed`;
    /// - finished with an error: that error;
    /// - panicked: a worker-panic engine error built from the panic payload;
    /// - cancelled without panicking: `ClientError::TopicClosed`.
    ///
    /// NOTE(review): once the handle has yielded its output, polling it again
    /// (a second call here, or a later `poll_close`) may panic inside Tokio —
    /// confirm and consider tracking completion explicitly.
    fn get_error(&mut self) -> crate::Error {
        match (&mut self.handle).now_or_never() {
            Some(Ok(Ok(_))) | None => crate::ClientError::TopicClosed.into(),
            // `into_panic` would itself panic for a cancelled task, so only
            // treat the join error as a panic when it actually carries one.
            Some(Err(err)) => match err.try_into_panic() {
                Ok(payload) => EngineError::worker_panic("flight_publisher", &payload).into(),
                Err(_) => crate::ClientError::TopicClosed.into(),
            },
            Some(Ok(Err(err))) => err,
        }
    }
}

impl Sink<RecordBatch> for FlightPublisher {
    type Error = crate::Error;

    #[inline]
    fn poll_ready(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.send.poll_ready_unpin(cx).map_err(|_| self.get_error())
    }

    #[inline]
    fn start_send(
        mut self: std::pin::Pin<&mut Self>,
        item: RecordBatch,
    ) -> Result<(), Self::Error> {
        self.send
            .start_send_unpin(item)
            .map_err(|_| self.get_error())
    }

    #[inline]
    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.send.poll_flush_unpin(cx).map_err(|_| self.get_error())
    }

    fn poll_close(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        if futures::ready!(self.send.poll_close_unpin(cx)).is_err() {
            return Poll::Ready(Err(self.get_error()));
        }
        if !self.stop.is_aborted() {
            self.stop.abort();
        }
        Poll::Ready(match futures::ready!(self.handle.poll_unpin(cx)) {
            Ok(Ok(())) => Ok(()),
            Ok(Err(err)) => Err(err),
            Err(err) => resume_unwind(err.into_panic()),
        })
    }
}