liquid_ml/dataframe/distributed_dataframe.rs

//! Defines functionality for a data frame that is split across different
//! physical machines.
use crate::dataframe::{local_dataframe::LocalDataFrame, Row, Rower, Schema};
use crate::error::LiquidError;
use crate::kv::{KVStore, Key};
use crate::network::{Client, FramedStream};
use bincode::{deserialize, serialize};
use futures::stream::{SelectAll, StreamExt};
use log::{debug, info};
use rand::{self, Rng};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use sorer::dataframe::{Column, Data, SorTerator};
use std::cmp;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::ops::Range;
use std::sync::Arc;
use tokio::sync::{
    mpsc::{self, Receiver, Sender},
    Mutex, Notify, RwLock,
};

/// Represents a distributed, immutable data frame which contains data stored
/// in a columnar format and a well defined [`Schema`]. Provides convenient
/// `map` and `filter` methods that operate on the entire distributed data
/// frame (i.e., across different machines) with a given [`Rower`].
///
/// [`Rower`]: trait.Rower.html
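///
/// A minimal usage sketch (hedged: `NullCounter` is a hypothetical
/// user-defined `Rower`, and `ddf` is an `Arc<DistributedDataFrame>` obtained
/// from a `LiquidML` instance; the names are illustrative, not part of this
/// module):
///
/// ```ignore
/// // Hypothetical rower that counts rows containing a null value.
/// let rower = NullCounter::new();
/// if let Some(result) = ddf.map(rower).await? {
///     // Only node 1 receives the joined result; other nodes get `None`.
///     println!("rows with nulls: {}", result.count);
/// }
/// ```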
#[derive(Debug)]
pub struct DistributedDataFrame {
    /// The `Schema` of this `DistributedDataFrame`
    pub schema: Schema,
    /// The name of this `DistributedDataFrame`. Must be unique in a `LiquidML`
    /// instance
    pub df_name: String,
    /// A map of the range of row indices to the `Key`s that point to the chunk
    /// of data with those rows. Not all `Key`s in this map belong to this node
    /// of the `DistributedDataFrame`; some may belong to other nodes
    pub df_chunk_map: HashMap<Range<usize>, Key>,
    /// The number of rows in this entire `DistributedDataFrame`
    pub num_rows: usize,
    /// The id of the node this `DistributedDataFrame` is running on
    pub node_id: usize,
    /// The total number of nodes in this `DistributedDataFrame`
    pub num_nodes: usize,
    /// The address of the `Server`
    pub server_addr: String,
    /// The IP address of this node
    pub my_ip: String,
    /// Used for communication with other nodes in this `DistributedDataFrame`
    network: Arc<Mutex<Client<DistributedDFMsg>>>,
    /// The `KVStore`, which stores the serialized data owned by this
    /// `DistributedDataFrame` and deserialized cached data that may or may
    /// not belong to this node
    kv: Arc<KVStore<LocalDataFrame>>,
    /// Used for processing messages so that the asynchronous task running
    /// the `process_messages` function can notify other asynchronous tasks
    /// when the `row` of this `DistributedDataFrame` is ready to use for
    /// operations (such as returning the result to the `get_row` function)
    internal_notifier: Arc<Notify>,
    /// Is mutated by the asynchronous `process_messages` task to be a
    /// requested row when the network responds to `GetRow` requests, to
    /// enable getter methods for data such as `get_row`
    row: Arc<RwLock<Row>>,
    /// A notifier that gets notified when the `Server` has sent a `Kill`
    /// message to this `DistributedDataFrame`'s network `Client`
    _kill_notifier: Arc<Notify>,
    /// Used for lower level messages, such as sending arbitrary `Rower`s
    blob_receiver: Mutex<Receiver<Vec<u8>>>,
    /// Used for processing filter results TODO: maybe a better way to do this
    filter_results: Mutex<Receiver<DistributedDFMsg>>,
}

/// Represents the kinds of messages sent between `DistributedDataFrame`s
#[derive(Debug, Serialize, Deserialize, Clone)]
pub(crate) enum DistributedDFMsg {
    /// A message used to request the `Row` with the given index from another
    /// node in a `DistributedDataFrame`
    GetRow(usize),
    /// A message used to respond to `GetRow` messages with the requested row
    Row(Row),
    /// A message used to tell the 1st node the results from using the `filter`
    /// method. If there were no rows after filtering, then `filtered_df_key`
    /// is `None` and `num_rows` is `0`.
    FilterResult {
        num_rows: usize,
        filtered_df_key: Option<Key>,
    },
    /// A message used to share random blobs of data with other nodes. This
    /// provides a lower level interface to facilitate other kinds of messages,
    /// for example sending rowers when performing `map`/`filter`.
    Blob(Vec<u8>),
    /// A message used to share the information that every node needs in order
    /// to construct a `DistributedDataFrame` struct that is consistent across
    /// all nodes.
    Initialization {
        schema: Schema,
        df_chunk_map: HashMap<Range<usize>, Key>,
    },
}

impl DistributedDataFrame {
    /// Creates a new `DistributedDataFrame` from the given file. It is
    /// assumed that node 1 contains the file with the given `file_name`.
    /// Node 1 will then parse that file and distribute chunks to other nodes
    /// over the network, so if network latency is a concern you should not
    /// use this method.
    pub(crate) async fn from_sor(
        server_addr: &str,
        my_ip: &str,
        file_name: &str,
        kv: Arc<KVStore<LocalDataFrame>>,
        df_name: &str,
        num_nodes: usize,
    ) -> Result<Arc<Self>, LiquidError> {
        // make a chunking iterator for the SoR file; only node 1 has the file
        let sor_terator = if kv.id == 1 {
            let total_newlines = count_new_lines(file_name);
            let max_rows_per_node = total_newlines / num_nodes;
            let schema = sorer::schema::infer_schema(file_name)?;
            info!(
                "Total newlines: {} max rows per node: {}",
                total_newlines, max_rows_per_node
            );
            info!("Inferred schema: {:?}", &schema);
            Some(SorTerator::new(file_name, schema, max_rows_per_node))
        } else {
            None
        };
        DistributedDataFrame::from_iter(
            server_addr,
            my_ip,
            sor_terator,
            kv,
            df_name,
            num_nodes,
        )
        .await
    }

    /// Creates a new `DistributedDataFrame` from the given iterator. The
    /// iterator is used only on node 1, which calls `next` on it and
    /// distributes chunks concurrently, so `iter` must be `Some` on node 1
    /// and may be `None` on all other nodes.
    pub(crate) async fn from_iter(
        server_addr: &str,
        my_ip: &str,
        iter: Option<impl Iterator<Item = Vec<Column>>>,
        kv: Arc<KVStore<LocalDataFrame>>,
        df_name: &str,
        num_nodes: usize,
    ) -> Result<Arc<Self>, LiquidError> {
        // Figure out what node we are supposed to be
        let node_id = kv.id;
        // initialize some other required fields of self so as not to duplicate
        // code in if branches
        let (blob_sender, blob_receiver) = mpsc::channel(2);
        // used for internal message processing so that the asynchronous
        // messaging task can notify other tasks when `self.row` is ready
        let internal_notifier = Arc::new(Notify::new());
        // so that our network client can notify us when they get a Kill
        // signal
        let _kill_notifier = Arc::new(Notify::new());
        // so that our client only connects to clients for this dataframe
        let df_network_name = format!("ddf-{}", df_name);
        // for processing results when distributed filtering is performed
        // on this `DistributedDataFrame`
        let (filter_results_sender, filter_results) = mpsc::channel(num_nodes);
        let filter_results = Mutex::new(filter_results);

        let (network, mut read_streams, __kill_notifier) =
            Client::register_network(
                kv.network.clone(),
                df_network_name.to_string(),
            )
            .await?;
        assert_eq!(node_id, { network.lock().await.id });

        // Node 1 is responsible for sending out chunks
        if node_id == 1 {
            // Distribute the chunked data round-robin style
            let mut df_chunk_map = HashMap::new();
            let mut cur_num_rows = 0;
            let mut schema = None;
            {
                // in each iteration, spawn a task that sends a chunk to a node
                let mut chunk_idx = 0;
                for chunk in iter.unwrap() {
                    if chunk_idx == 0 {
                        schema = Some(Schema::from(&chunk));
                    }

                    let ldf = LocalDataFrame::from(chunk);
                    if chunk_idx > 0 {
                        // assert all chunks have the same schema
                        assert_eq!(schema.as_ref(), Some(ldf.get_schema()));
                    }

                    // make the key that will be associated with this chunk
                    let key =
                        Key::generate(df_name, (chunk_idx % num_nodes) + 1);
                    // add this chunk's range and key to our <range, key> map
                    df_chunk_map.insert(
                        Range {
                            start: cur_num_rows,
                            end: cur_num_rows + ldf.n_rows(),
                        },
                        key.clone(),
                    );
                    cur_num_rows += ldf.n_rows();

                    let kv_ptr = kv.clone();
                    tokio::spawn(async move {
                        kv_ptr.put(key, ldf).await.unwrap();
                    });

                    // NOTE: might need to do some tuning on when to join the
                    // futures here, possibly even dynamically figure out some
                    // value to smooth over the tradeoff between memory and
                    // speed
                    chunk_idx += 1;
                }

                // we are almost done distributing chunks
                info!("Finished distributing {} chunks", chunk_idx);
            }

            // Create an Initialization message that holds all the information
            // related to this DistributedDataFrame: the Schema and the map
            // from the range of indices that each chunk holds to the `Key`
            // associated with that chunk
            let schema = schema.unwrap();
            let intro_msg = DistributedDFMsg::Initialization {
                schema: schema.clone(),
                df_chunk_map: df_chunk_map.clone(),
            };

            // Broadcast the initialization message to all nodes
            network.lock().await.broadcast(intro_msg).await?;
            debug!("Node 1 sent the initialization message to all nodes");

            let row = Arc::new(RwLock::new(Row::new(&schema)));

            let ddf = Arc::new(DistributedDataFrame {
                schema,
                df_name: df_name.to_string(),
                df_chunk_map,
                num_rows: cur_num_rows,
                network,
                node_id,
                num_nodes,
                server_addr: server_addr.to_string(),
                my_ip: my_ip.to_string(),
                kv,
                internal_notifier,
                row,
                _kill_notifier,
                blob_receiver: Mutex::new(blob_receiver),
                filter_results,
            });

            // spawn a tokio task to process messages
            let ddf_clone = ddf.clone();
            tokio::spawn(async move {
                DistributedDataFrame::process_messages(
                    ddf_clone,
                    read_streams,
                    blob_sender,
                    filter_results_sender,
                )
                .await
                .unwrap();
            });

            Ok(ddf)
        } else {
            // Node 1 will send the initialization message to our network
            let init_msg = read_streams.next().await.unwrap()?;
            // We got a message, check it was the initialization message
            let (schema, df_chunk_map) = match init_msg.msg {
                DistributedDFMsg::Initialization {
                    schema,
                    df_chunk_map,
                } => (schema, df_chunk_map),
                _ => return Err(LiquidError::UnexpectedMessage),
            };
            debug!("Got the Initialization message from Node 1");

            let row = Arc::new(RwLock::new(Row::new(&schema)));
            // the total number of rows is the largest end index of any chunk
            let num_rows =
                df_chunk_map.keys().map(|k| k.end).max().unwrap_or(0);

            let ddf = Arc::new(DistributedDataFrame {
                schema,
                df_name: df_name.to_string(),
                df_chunk_map,
                num_rows,
                network,
                node_id,
                num_nodes,
                server_addr: server_addr.to_string(),
                my_ip: my_ip.to_string(),
                kv,
                internal_notifier,
                row,
                _kill_notifier,
                blob_receiver: Mutex::new(blob_receiver),
                filter_results,
            });

            // spawn a tokio task to process messages
            let ddf_clone = ddf.clone();
            tokio::spawn(async move {
                DistributedDataFrame::process_messages(
                    ddf_clone,
                    read_streams,
                    blob_sender,
                    filter_results_sender,
                )
                .await
                .unwrap();
            });

            Ok(ddf)
        }
    }

    // TODO: add some verification that the `data` is not jagged. A function
    //       that is a no-op if it's not jagged, and otherwise inserts nulls
    //       to fix it, would be nice.

    /// Creates a new `DistributedDataFrame` by chunking the given `data` into
    /// evenly sized chunks and distributing them across all nodes. The size of
    /// each chunk is the total number of rows in `data` divided by the number
    /// of nodes, since this was found to have the best performance for `map`
    /// and `filter`. Node 1 is responsible for distributing the data, and thus
    /// `data` should only be `Some` on node 1.
    ///
    /// NOTE: this function currently does not verify that `data` is not
    /// jagged, which is a required invariant of the program. There is a plan
    /// to automatically fix jagged data.
    pub(crate) async fn new(
        server_addr: &str,
        my_ip: &str,
        data: Option<Vec<Column>>,
        kv: Arc<KVStore<LocalDataFrame>>,
        df_name: &str,
        num_nodes: usize,
    ) -> Result<Arc<Self>, LiquidError> {
        let num_rows = if let Some(d) = &data { n_rows(d) } else { 0 };
        // use a chunk size of at least 1 so that data with fewer rows than
        // there are nodes still gets distributed instead of being dropped
        let chunk_size = cmp::max(num_rows / num_nodes, 1);
        let chunkerator = if data.is_some() {
            Some(DataChunkerator { chunk_size, data })
        } else {
            None
        };
        DistributedDataFrame::from_iter(
            server_addr,
            my_ip,
            chunkerator,
            kv,
            df_name,
            num_nodes,
        )
        .await
    }

    /// Obtains a reference to this `DistributedDataFrame`'s schema.
    pub fn get_schema(&self) -> &Schema {
        &self.schema
    }

    /// Get the data at the given `col_idx`, `row_idx` offsets as a `Data`
    /// enum value
    pub async fn get(
        &self,
        col_idx: usize,
        row_idx: usize,
    ) -> Result<Data, LiquidError> {
        let r = self.get_row(row_idx).await?;
        Ok(r.get(col_idx)?.clone())
    }

    /// Returns a clone of the row at the requested `index`
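    ///
    /// A brief, hedged sketch of how a caller might use this (assumes `ddf`
    /// is an `Arc<DistributedDataFrame>` and the index is in bounds):
    ///
    /// ```ignore
    /// // The row may live on another node; if so it is requested over the
    /// // network and this call waits for the response.
    /// let row = ddf.get_row(42).await?;
    /// println!("row 42: {:?}", row);
    /// ```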
    pub async fn get_row(&self, index: usize) -> Result<Row, LiquidError> {
        match self.df_chunk_map.iter().find(|(k, _)| k.contains(&index)) {
            Some((range, key)) => {
                // key is either owned by us or another node
                if key.home == self.node_id {
                    // we own it
                    let our_local_df = self.kv.get(&key).await?;
                    let mut r = Row::new(self.get_schema());
                    // TODO: is this index for fill_row correct?
                    our_local_df.fill_row(index - range.start, &mut r)?;
                    Ok(r)
                } else {
                    // owned by another node, must request over the network
                    let get_msg = DistributedDFMsg::GetRow(index);
                    {
                        self.network
                            .lock()
                            .await
                            .send_msg(key.home, get_msg)
                            .await?;
                    }
                    // wait here until we are notified the row is set by our
                    // message processing task
                    self.internal_notifier.notified().await;
                    // self.row is now set
                    Ok(self.row.read().await.clone())
                }
            }
            None => Err(LiquidError::RowIndexOutOfBounds),
        }
    }

    /// Get the index of the `Column` with the given `col_name`. Returns `Some`
    /// if a `Column` with the given name exists, or `None` otherwise.
    pub fn get_col_idx(&self, col_name: &str) -> Option<usize> {
        self.schema.col_idx(col_name)
    }

    /// Perform a distributed map operation on this `DistributedDataFrame` with
    /// the given `rower`. Returns `Some(rower)` (of the joined results) if the
    /// `node_id` of this `DistributedDataFrame` is `1`, and `None` otherwise.
    ///
    /// A local `pmap` is used on each node to map over that node's chunks.
    /// By default, each node will use the number of threads available on that
    /// machine.
    ///
    /// NOTE:
    /// There is an important design decision here with a distinct trade-off:
    /// 1. Join the last node with the next one until you get to the end. This
    ///    has reduced memory requirements but a performance impact because
    ///    of the synchronous network calls
    /// 2. Join all nodes with one node by sending network messages
    ///    concurrently to the final node. This has increased memory
    ///    requirements and greater complexity but greater performance because
    ///    all nodes can asynchronously send to one node at the same time.
    ///
    /// This implementation went with option 1 for simplicity.
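    ///
    /// A minimal usage sketch (hedged: `WordCounter` is a hypothetical
    /// user-defined `Rower` that implements `Clone + Send + Serialize +
    /// DeserializeOwned`, and `ddf` is an `Arc<DistributedDataFrame>`):
    ///
    /// ```ignore
    /// let rower = WordCounter::new();
    /// match ddf.map(rower).await? {
    ///     // node 1 gets the rower joined across every node's chunks
    ///     Some(joined) => println!("total words: {}", joined.total),
    ///     // every other node gets `None`
    ///     None => {}
    /// }
    /// ```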
    pub async fn map<T: Rower + Clone + Send + Serialize + DeserializeOwned>(
        &self,
        mut rower: T,
    ) -> Result<Option<T>, LiquidError> {
        // get the keys for our locally owned chunks
        let my_keys: Vec<&Key> = self
            .df_chunk_map
            .iter()
            .filter(|(_, key)| key.home == self.node_id)
            .map(|(_, v)| v)
            .collect();
        // map over our chunks
        for key in my_keys {
            // TODO: shouldn't need wait_and_get here since we own that chunk
            let ldf = self.kv.wait_and_get(key).await?;
            rower = ldf.pmap(rower);
        }
        if self.node_id == self.num_nodes {
            // we are the last node: start the join chain by sending our
            // results to the node before us
            self.send_blob(self.node_id - 1, &rower).await?;
            debug!("Last node sent its results");
            Ok(None)
        } else {
            // wait for the joined results of all nodes after us
            let blob =
                { self.blob_receiver.lock().await.recv().await.unwrap() };
            let external_rower: T = deserialize(&blob[..])?;
            rower = rower.join(external_rower);
            debug!("Received a resulting rower and joined it with local rower");
            if self.node_id != 1 {
                self.send_blob(self.node_id - 1, &rower).await?;
                debug!("Forwarded the combined rower");
                Ok(None)
            } else {
                debug!("Node 1 joined the final rower, map is complete");
                Ok(Some(rower))
            }
        }
    }

    // TODO: maybe abstract this into an iterator and use the from_iter
    //       function, since a **lot** of code here is copy-pasted from it.
    //       One issue: filter needs to generate a client-type that is unique
    //       to the filtered dataframe, but from_iter assumes the client-type
    //       is `ddf`. We could make a private from_iter_and_type method
    //       that also accepts the client-type, and then from_iter passes in
    //       "ddf" while filter passes in the generated client-type

    /// Perform a distributed filter operation on this `DistributedDataFrame`.
    /// This function does not mutate the `DistributedDataFrame` in any way;
    /// instead, it creates a new `DistributedDataFrame` of the results. This
    /// new `DistributedDataFrame` is returned to every node so that the
    /// results are consistent everywhere.
    ///
    /// A local `pfilter` is used on each node to filter over that node's
    /// chunks. By default, each node will use the number of threads available
    /// on that machine.
    ///
    /// It is possible to re-write this to use a bit map of the rows that
    /// should remain in the filtered result, but currently this just clones
    /// the rows.
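    ///
    /// A minimal usage sketch (hedged: `AboveThreshold` is a hypothetical
    /// `Rower` whose `accept` keeps only the rows of interest, and `ddf` is
    /// an `Arc<DistributedDataFrame>`):
    ///
    /// ```ignore
    /// let rower = AboveThreshold::new(100);
    /// // every node gets back the same new, filtered DistributedDataFrame
    /// let filtered = ddf.filter(rower).await?;
    /// println!("{} rows passed the filter", filtered.n_rows());
    /// ```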
    pub async fn filter<
        T: Rower + Clone + Send + Serialize + DeserializeOwned,
    >(
        &self,
        mut rower: T,
    ) -> Result<Arc<Self>, LiquidError> {
        // so that our network client can notify us when they get a Kill
        // signal
        let _kill_notifier = Arc::new(Notify::new());
        let mut rng = rand::thread_rng();
        let r = rng.gen::<i16>();
        let new_name = format!("{}-filtered-{}", &self.df_name, r);
        let df_network_name = format!("ddf-{}", new_name);
        let (network, mut read_streams, __kill_notifier) =
            Client::register_network(
                self.kv.network.clone(),
                df_network_name.to_string(),
            )
            .await?;
        assert_eq!(self.node_id, { network.lock().await.id });

        // get the keys for our locally owned chunks
        let my_keys: Vec<&Key> = self
            .df_chunk_map
            .iter()
            .filter(|(_, key)| key.home == self.node_id)
            .map(|(_, v)| v)
            .collect();
        // NOTE: combines all chunks into one final chunk, may want to change
        // to stay 1-1
        // filter over our locally owned chunks
        let mut filtered_ldf = LocalDataFrame::new(self.get_schema());
        for key in &my_keys {
            // TODO: should not really need wait_and_get here since we own that chunk
            let ldf = self.kv.wait_and_get(key).await?;
            filtered_ldf = filtered_ldf.combine(ldf.pfilter(&mut rower))?;
        }

        // initialize some other required fields of self so as not to duplicate
        // code in if branches
        let (blob_sender, blob_receiver) = mpsc::channel(2);
        // used for internal message processing so that the asynchronous
        // messaging task can notify other tasks when `self.row` is ready
        let internal_notifier = Arc::new(Notify::new());
        // for processing results of distributed filtering
        let (filter_results_sender, filter_results) =
            mpsc::channel(self.num_nodes);
        let filter_results = Mutex::new(filter_results);

        let num_rows_left = filtered_ldf.n_rows();
        info!(
            "Finished filtering {} local chunk(s), have {} rows after filter",
            my_keys.len(),
            num_rows_left
        );

        // put our result in our KVStore only if it's not empty
        let mut key = None;
        if num_rows_left > 0 {
            let k = Key::generate(&new_name, self.node_id);
            key = Some(k.clone());
            self.kv.put(k, filtered_ldf).await?;
        }

        if self.node_id == 1 {
            // 2. collect all results from other nodes (insert ours first)
            let mut df_chunk_map = HashMap::new();
            let mut cur_num_rows = 0;
            if let Some(key) = key {
                df_chunk_map.insert(
                    Range {
                        start: cur_num_rows,
                        end: cur_num_rows + num_rows_left,
                    },
                    key,
                );
                cur_num_rows += num_rows_left;
            }

            // our own result counts as the first one received
            let mut results_received = 1;
            // TODO: maybe a better way to pass around these results
            {
                let mut unlocked = self.filter_results.lock().await;
                while results_received < self.num_nodes {
                    let msg = unlocked.recv().await.unwrap();
                    match msg {
                        DistributedDFMsg::FilterResult {
                            num_rows,
                            filtered_df_key,
                        } => {
                            match filtered_df_key {
                                Some(k) => {
                                    df_chunk_map.insert(
                                        Range {
                                            start: cur_num_rows,
                                            end: cur_num_rows + num_rows,
                                        },
                                        k,
                                    );
                                    cur_num_rows += num_rows;
                                }
                                None => {
                                    assert_eq!(num_rows, 0);
                                }
                            }
                            results_received += 1;
                        }
                        _ => return Err(LiquidError::UnexpectedMessage),
                    }
                }
                debug!("Got all filter results from other nodes");
            }

            // 3. broadcast initialization message

            // Create an Initialization message that holds all the information
            // related to this DistributedDataFrame: the Schema and the map
            // from the range of indices that each chunk holds to the `Key`
            // associated with that chunk
            let intro_msg = DistributedDFMsg::Initialization {
                schema: self.get_schema().clone(),
                df_chunk_map: df_chunk_map.clone(),
            };

            // Broadcast the initialization message to all nodes
            network.lock().await.broadcast(intro_msg).await?;
            debug!("Node 1 sent the initialization message to all nodes");

            // 4. initialize self
            let row = Arc::new(RwLock::new(Row::new(self.get_schema())));
            // the total number of rows is the largest end index of any chunk
            let num_rows =
                df_chunk_map.keys().map(|k| k.end).max().unwrap_or(0);

            let ddf = Arc::new(DistributedDataFrame {
                schema: self.get_schema().clone(),
                df_name: new_name,
                df_chunk_map,
                num_rows,
                network,
                node_id: self.node_id,
                num_nodes: self.num_nodes,
                server_addr: self.server_addr.clone(),
                my_ip: self.my_ip.clone(),
                kv: self.kv.clone(),
                internal_notifier,
                row,
                _kill_notifier,
                blob_receiver: Mutex::new(blob_receiver),
                filter_results,
            });

            // spawn a tokio task to process messages
            let ddf_clone = ddf.clone();
            tokio::spawn(async move {
                DistributedDataFrame::process_messages(
                    ddf_clone,
                    read_streams,
                    blob_sender,
                    filter_results_sender,
                )
                .await
                .unwrap();
            });

            Ok(ddf)
        } else {
            // send our FilterResult to node 1
            let results = DistributedDFMsg::FilterResult {
                num_rows: num_rows_left,
                filtered_df_key: key,
            };
            network.lock().await.send_msg(1, results).await?;
            // Node 1 will send the initialization message to our network
            let init_msg = read_streams.next().await.unwrap()?;
            // We got a message, check it was the initialization message
            let (schema, df_chunk_map) = match init_msg.msg {
                DistributedDFMsg::Initialization {
                    schema,
                    df_chunk_map,
                } => (schema, df_chunk_map),
                _ => return Err(LiquidError::UnexpectedMessage),
            };
            debug!("Got the Initialization message from Node 1");

            // 4. initialize self
            let row = Arc::new(RwLock::new(Row::new(&schema)));
            // the total number of rows is the largest end index of any chunk
            let num_rows =
                df_chunk_map.keys().map(|k| k.end).max().unwrap_or(0);

            let ddf = Arc::new(DistributedDataFrame {
                schema,
                df_name: new_name,
                df_chunk_map,
                num_rows,
                network,
                node_id: self.node_id,
                num_nodes: self.num_nodes,
                server_addr: self.server_addr.clone(),
                my_ip: self.my_ip.clone(),
                kv: self.kv.clone(),
                internal_notifier,
                row,
                _kill_notifier,
                blob_receiver: Mutex::new(blob_receiver),
                filter_results,
            });

            // spawn a tokio task to process messages
            let ddf_clone = ddf.clone();
            tokio::spawn(async move {
                DistributedDataFrame::process_messages(
                    ddf_clone,
                    read_streams,
                    blob_sender,
                    filter_results_sender,
                )
                .await
                .unwrap();
            });

            Ok(ddf)
        }
    }

    /// Return the (total) number of rows across all nodes for this
    /// `DistributedDataFrame`
    pub fn n_rows(&self) -> usize {
        self.num_rows
    }

    /// Return the number of columns in this `DistributedDataFrame`.
    pub fn n_cols(&self) -> usize {
        self.schema.width()
    }

    /// Sends the given `blob` to the `DistributedDataFrame` with the given
    /// `target_id`. This provides a lower level interface to facilitate other
    /// kinds of messages, such as sending serialized `Rower`s
    async fn send_blob<T: Serialize>(
        &self,
        target_id: usize,
        blob: &T,
    ) -> Result<(), LiquidError> {
        let blob = serialize(blob)?;
        self.network
            .lock()
            .await
            .send_msg(target_id, DistributedDFMsg::Blob(blob))
            .await
    }

    /// Processes `DistributedDFMsg` messages received on `read_streams`.
    /// When a message is received, a new `tokio` task is spawned to handle
    /// it so that the receiving task is not blocked and new messages can be
    /// read and processed concurrently.
    async fn process_messages(
        ddf: Arc<DistributedDataFrame>,
        mut read_streams: SelectAll<FramedStream<DistributedDFMsg>>,
        blob_sender: Sender<Vec<u8>>,
        filter_results_sender: Sender<DistributedDFMsg>,
    ) -> Result<(), LiquidError> {
        while let Some(Ok(msg)) = read_streams.next().await {
            let mut blob_sender_clone = blob_sender.clone();
            let mut filter_res_sender = filter_results_sender.clone();
            let ddf2 = ddf.clone();
            tokio::spawn(async move {
                match msg.msg {
                    DistributedDFMsg::GetRow(row_idx) => {
                        // look up the requested row locally and send it back
                        // to the node that asked for it
                        let r = ddf2.get_row(row_idx).await.unwrap();
                        {
                            ddf2.network
                                .lock()
                                .await
                                .send_msg(
                                    msg.sender_id,
                                    DistributedDFMsg::Row(r),
                                )
                                .await
                                .unwrap();
                        }
                    }
                    DistributedDFMsg::Row(row) => {
                        // store the row, then wake up the task waiting on it
                        // in `get_row`
                        {
                            *ddf2.row.write().await = row;
                        }
                        ddf2.internal_notifier.notify();
                    }
                    DistributedDFMsg::Blob(blob) => {
                        blob_sender_clone.send(blob).await.unwrap();
                    }
                    DistributedDFMsg::FilterResult {
                        num_rows,
                        filtered_df_key,
                    } => {
                        filter_res_sender
                            .send(DistributedDFMsg::FilterResult {
                                num_rows,
                                filtered_df_key,
                            })
                            .await
                            .unwrap();
                    }
                    _ => panic!(
                        "Initialization messages must be handled before the \
                         message processing loop is started"
                    ),
                }
            });
        }

        Ok(())
    }
}

/// A simple struct to help chunk a `Vec<Column>` into chunks with a given
/// number of rows each
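///
/// A small, hedged sketch of the chunking behavior (assumes `data` is a
/// non-jagged `Vec<Column>` already in hand, e.g. produced by `sorer`; with
/// 5 rows and a `chunk_size` of 2 this yields chunks of 2, 2, and 1 rows):
///
/// ```ignore
/// let chunker = DataChunkerator { chunk_size: 2, data: Some(data) };
/// for chunk in chunker {
///     println!("chunk has {} rows", n_rows(&chunk));
/// }
/// ```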
#[derive(Debug)]
struct DataChunkerator {
    /// how many rows in each chunk
    chunk_size: usize,
    /// Optional because it's assumed that only node 1 has the data
    data: Option<Vec<Column>>,
}

impl Iterator for DataChunkerator {
    type Item = Vec<Column>;

    /// Advances this iterator by breaking off `self.chunk_size` rows of its
    /// data until the data is empty. The last chunk may be smaller than
    /// `self.chunk_size`.
    fn next(&mut self) -> Option<Self::Item> {
        if let Some(data) = &mut self.data {
            // we are node 1 and have the data
            let cur_chunk_size = cmp::min(self.chunk_size, n_rows(&data));
            if cur_chunk_size == 0 {
                // the data has been consumed
                None
            } else {
                // there is more data to chunk
                let mut chunked_data = Vec::with_capacity(data.len());
                for col in data {
                    // drain would panic if `cur_chunk_size` were greater than
                    // the column length, which can only happen if the data
                    // is jagged
                    let new_col = match col {
                        Column::Int(i) => {
                            Column::Int(i.drain(0..cur_chunk_size).collect())
                        }
                        Column::Bool(i) => {
                            Column::Bool(i.drain(0..cur_chunk_size).collect())
                        }
                        Column::Float(i) => {
                            Column::Float(i.drain(0..cur_chunk_size).collect())
                        }
                        Column::String(i) => {
                            Column::String(i.drain(0..cur_chunk_size).collect())
                        }
                    };
                    chunked_data.push(new_col);
                }
                Some(chunked_data)
            }
        } else {
            // we are not node 1, so we don't have the data
            None
        }
    }
}

fn n_rows(data: &[Column]) -> usize {
    match data.get(0) {
        None => 0,
        Some(x) => match x {
            Column::Int(c) => c.len(),
            Column::Float(c) => c.len(),
            Column::Bool(c) => c.len(),
            Column::String(c) => c.len(),
        },
    }
}

fn count_new_lines(file_name: &str) -> usize {
    let mut buf_reader = BufReader::new(File::open(file_name).unwrap());
    let mut new_lines = 0;

    loop {
        let bytes_read = buf_reader.fill_buf().unwrap();
        let len = bytes_read.len();
        if len == 0 {
            return new_lines;
        }
        new_lines += bytecount::count(bytes_read, b'\n');
        buf_reader.consume(len);
    }
}