1use std::sync::Arc;
4
5use arrow_array::RecordBatch;
6use arrow_schema::{Schema, SchemaRef};
7
8use crate::errors::{Result, RpcError};
9use crate::log::{LogLevel, LogMessage};
10use crate::wire::Metadata;
11
12pub(crate) enum Emitted {
14 Batch {
15 batch: RecordBatch,
16 metadata: Option<Metadata>,
17 },
18 Log(LogMessage),
19}
20
21pub struct OutputCollector {
23 schema: SchemaRef,
24 pub(crate) items: Vec<Emitted>,
25 finished: bool,
26 is_producer: bool,
27}
28
29impl OutputCollector {
30 pub(crate) fn new(schema: SchemaRef, is_producer: bool) -> Self {
31 Self {
32 schema,
33 items: Vec::new(),
34 finished: false,
35 is_producer,
36 }
37 }
38
39 pub fn schema(&self) -> SchemaRef {
41 self.schema.clone()
42 }
43
44 pub fn emit(&mut self, batch: RecordBatch) -> Result<()> {
46 if batch.schema() != self.schema {
47 return Err(RpcError::runtime_error(format!(
48 "emit(): schema mismatch — expected {:?}, got {:?}",
49 self.schema.fields(),
50 batch.schema().fields()
51 )));
52 }
53 self.items.push(Emitted::Batch {
54 batch,
55 metadata: None,
56 });
57 Ok(())
58 }
59
60 pub fn emit_with_metadata(&mut self, batch: RecordBatch, metadata: Metadata) -> Result<()> {
63 if batch.schema() != self.schema {
64 return Err(RpcError::runtime_error(format!(
65 "emit_with_metadata(): schema mismatch — expected {:?}, got {:?}",
66 self.schema.fields(),
67 batch.schema().fields()
68 )));
69 }
70 self.items.push(Emitted::Batch {
71 batch,
72 metadata: Some(metadata),
73 });
74 Ok(())
75 }
76
77 pub fn finish(&mut self) {
79 self.finished = true;
80 }
81
82 pub fn finished(&self) -> bool {
83 self.finished
84 }
85
86 pub fn client_log(&mut self, level: LogLevel, message: impl Into<String>) {
88 self.items
89 .push(Emitted::Log(LogMessage::new(level, message)));
90 }
91
92 pub fn client_log_with(&mut self, msg: LogMessage) {
94 self.items.push(Emitted::Log(msg));
95 }
96
97 pub fn is_producer(&self) -> bool {
98 self.is_producer
99 }
100}
101
102pub trait ProducerState: Send {
104 fn produce(&mut self, out: &mut OutputCollector, ctx: &CallContext) -> Result<()>;
105
106 fn on_cancel(&mut self, _ctx: &CallContext) {}
108
109 fn batch_limit(&self) -> Option<usize> {
114 None
115 }
116
117 fn encode_state(&self) -> Result<Vec<u8>> {
122 Err(RpcError::runtime_error(
123 "producer state does not implement encode_state(); \
124 override this method or register the method via MethodInfo::stream_with_codec",
125 ))
126 }
127}
128
129pub trait ExchangeState: Send {
131 fn exchange(
132 &mut self,
133 input: &RecordBatch,
134 out: &mut OutputCollector,
135 ctx: &CallContext,
136 ) -> Result<()>;
137
138 fn on_cancel(&mut self, _ctx: &CallContext) {}
139
140 fn encode_state(&self) -> Result<Vec<u8>> {
143 Err(RpcError::runtime_error(
144 "exchange state does not implement encode_state(); \
145 override this method or register the method via MethodInfo::stream_with_codec",
146 ))
147 }
148}
149
150pub struct StreamResult {
153 pub output_schema: SchemaRef,
154 pub input_schema: Option<SchemaRef>,
156 pub state: StreamStateKind,
157 pub header: Option<RecordBatch>,
159 pub header_metadata: Option<Metadata>,
161}
162
163pub enum StreamStateKind {
164 Producer(Box<dyn ProducerState>),
165 Exchange(Box<dyn ExchangeState>),
166}
167
168impl StreamResult {
169 pub fn producer(schema: SchemaRef, state: Box<dyn ProducerState>) -> Self {
170 Self {
171 output_schema: schema,
172 input_schema: None,
173 state: StreamStateKind::Producer(state),
174 header: None,
175 header_metadata: None,
176 }
177 }
178
179 pub fn exchange(
180 output_schema: SchemaRef,
181 input_schema: SchemaRef,
182 state: Box<dyn ExchangeState>,
183 ) -> Self {
184 Self {
185 output_schema,
186 input_schema: Some(input_schema),
187 state: StreamStateKind::Exchange(state),
188 header: None,
189 header_metadata: None,
190 }
191 }
192
193 pub fn with_header(mut self, header: RecordBatch) -> Self {
194 self.header = Some(header);
195 self
196 }
197}
198
/// Returns a state decoder that rebuilds a producer state of type `S` from
/// bytes via `StreamStateCodec::decode` and wraps it as
/// [`StreamStateKind::Producer`]. Decode errors are propagated with `?`.
#[cfg(feature = "http")]
#[doc(hidden)]
pub fn producer_decoder<S>() -> crate::server::StateDecoder
where
    S: ProducerState + crate::stream_codec::StreamStateCodec + 'static,
{
    Arc::new(|bytes: &[u8]| {
        let decoded = S::decode(bytes)?;
        Ok(StreamStateKind::Producer(Box::new(decoded)))
    })
}
212
/// Returns a state decoder that rebuilds an exchange state of type `S` from
/// bytes via `StreamStateCodec::decode` and wraps it as
/// [`StreamStateKind::Exchange`]. Decode errors are propagated with `?`.
#[cfg(feature = "http")]
#[doc(hidden)]
pub fn exchange_decoder<S>() -> crate::server::StateDecoder
where
    S: ExchangeState + crate::stream_codec::StreamStateCodec + 'static,
{
    Arc::new(|bytes: &[u8]| {
        let decoded = S::decode(bytes)?;
        Ok(StreamStateKind::Exchange(Box::new(decoded)))
    })
}
225
226pub(crate) fn empty_schema() -> SchemaRef {
227 Arc::new(Schema::empty())
228}
229
230pub use crate::server::CallContext;