// oomclient/client.rs

1use crate::{
2    error::OomError,
3    oomagent,
4    oomagent::{
5        oom_agent_client::OomAgentClient,
6        ChannelExportRequest,
7        ChannelExportResponse,
8        ChannelImportRequest,
9        ChannelJoinRequest,
10        ChannelJoinResponse,
11        ExportRequest,
12        FeatureValueMap,
13        HealthCheckRequest,
14        ImportRequest,
15        JoinRequest,
16        OnlineGetRequest,
17        OnlineMultiGetRequest,
18        PushRequest,
19        SnapshotRequest,
20        SyncRequest,
21    },
22    server::ServerWrapper,
23    util::{parse_raw_feature_values, parse_raw_values},
24    EntityRow,
25    Result,
26    Value,
27};
28use async_stream::stream;
29use futures_core::stream::Stream;
30use std::{collections::HashMap, path::Path, sync::Arc};
31use tonic::{codegen::StdError, transport, Request};
32
33/// A rust client for [oomstore](https://github.com/oom-ai/oomstore),
34/// using the grpc protocol to communicate with oomagent server under the hood.
35#[derive(Debug, Clone)]
36pub struct Client {
37    client: OomAgentClient<transport::Channel>,
38    _agent: Option<Arc<ServerWrapper>>,
39}
40
41// TODO: Add a Builder to create the client
42impl Client {
43    /// Connect to an oomagent instance running on the given endpoint.
44    pub async fn connect<D>(dst: D) -> Result<Self>
45    where
46        D: std::convert::TryInto<tonic::transport::Endpoint>,
47        D::Error: Into<StdError>,
48    {
49        Ok(Self { client: OomAgentClient::connect(dst).await?, _agent: None })
50    }
51
52    /// Connect to an oomagent instance embedded with the client.
53    pub async fn with_embedded_oomagent<P1, P2>(bin_path: Option<P1>, cfg_path: Option<P2>) -> Result<Self>
54    where
55        P1: AsRef<Path>,
56        P2: AsRef<Path>,
57    {
58        let agent = ServerWrapper::new(bin_path, cfg_path, None).await?;
59        Ok(Self {
60            client: OomAgentClient::connect(format!("http://{}", agent.address())).await?,
61            _agent: Some(Arc::new(agent)),
62        })
63    }
64
65    /// Connect to an embedded oomagent instance with default config.
66    pub async fn with_default_embedded_oomagent() -> Result<Self> {
67        Self::with_embedded_oomagent(None::<String>, None::<String>).await
68    }
69
70    /// Check if oomagent is ready to serve requests.
71    pub async fn health_check(&mut self) -> Result<()> {
72        Ok(self.client.health_check(HealthCheckRequest {}).await.map(|_| ())?)
73    }
74
75    /// Get online features for an entity without further conversion.
76    pub async fn online_get_raw(
77        &mut self,
78        entity_key: impl Into<String>,
79        features: Vec<String>,
80    ) -> Result<FeatureValueMap> {
81        let res = self
82            .client
83            .online_get(OnlineGetRequest { entity_key: entity_key.into(), features })
84            .await?
85            .into_inner();
86        Ok(match res.result {
87            Some(res) => res,
88            None => FeatureValueMap::default(),
89        })
90    }
91
92    /// Get online features for an entity.
93    pub async fn online_get(
94        &mut self,
95        key: impl Into<String>,
96        features: Vec<String>,
97    ) -> Result<HashMap<String, Option<Value>>> {
98        let rs = self.online_get_raw(key, features).await?;
99        Ok(parse_raw_feature_values(rs))
100    }
101
102    /// Get online features for multiple entities without further conversion.
103    pub async fn online_multi_get_raw(
104        &mut self,
105        entity_keys: Vec<String>,
106        features: Vec<String>,
107    ) -> Result<HashMap<String, FeatureValueMap>> {
108        let res = self
109            .client
110            .online_multi_get(OnlineMultiGetRequest { entity_keys, features })
111            .await?
112            .into_inner();
113        Ok(res.result)
114    }
115
116    /// Get online features for multiple entities.
117    pub async fn online_multi_get(
118        &mut self,
119        keys: Vec<String>,
120        features: Vec<String>,
121    ) -> Result<HashMap<String, HashMap<String, Option<Value>>>> {
122        let rs = self.online_multi_get_raw(keys, features).await?;
123        Ok(rs.into_iter().map(|(k, v)| (k, parse_raw_feature_values(v))).collect())
124    }
125
126    /// Sync a certain revision of batch features from offline to online store.
127    pub async fn sync(
128        &mut self,
129        group: impl Into<String>,
130        revision_id: impl Into<Option<u32>>,
131        purge_delay: u32,
132    ) -> Result<()> {
133        let group = group.into();
134        let revision_id = revision_id.into().map(i32::try_from).transpose()?;
135        let purge_delay = i32::try_from(purge_delay)?;
136        self.client
137            .sync(SyncRequest { revision_id, group, purge_delay })
138            .await?;
139        Ok(())
140    }
141
142    /// Import features from external (batch and stream) data sources to offline store through channels.
143    pub async fn channel_import(
144        &mut self,
145        group: impl Into<String>,
146        revision: impl Into<Option<i64>>,
147        description: impl Into<Option<String>>,
148        rows: impl Stream<Item = Vec<u8>> + Send + 'static,
149    ) -> Result<u32> {
150        let mut group = Some(group.into());
151        let mut description = description.into();
152        let mut revision = revision.into();
153        let inbound = stream! {
154            for await row in rows {
155                yield ChannelImportRequest{group: group.take(), description: description.take(), revision: revision.take(), row};
156            }
157        };
158        let res = self.client.channel_import(Request::new(inbound)).await?.into_inner();
159        Ok(res.revision_id as u32)
160    }
161
162    /// Import features from external (batch and stream) data sources to offline store through files.
163    pub async fn import(
164        &mut self,
165        group: impl Into<String>,
166        revision: impl Into<Option<i64>>,
167        description: impl Into<Option<String>>,
168        input_file: impl AsRef<Path>,
169        delimiter: impl Into<Option<char>>,
170    ) -> Result<u32> {
171        let res = self
172            .client
173            .import(ImportRequest {
174                group:       group.into(),
175                description: description.into(),
176                revision:    revision.into(),
177                input_file:  input_file.as_ref().display().to_string(),
178                delimiter:   delimiter.into().map(String::from),
179            })
180            .await?
181            .into_inner();
182        Ok(res.revision_id as u32)
183    }
184
185    /// Push stream features from stream data source to both offline and online stores.
186    pub async fn push(
187        &mut self,
188        entity_key: impl Into<String>,
189        group: impl Into<String>,
190        kv_pairs: HashMap<String, Value>,
191    ) -> Result<()> {
192        let kv_pairs = kv_pairs
193            .into_iter()
194            .map(|(k, v)| (k, oomagent::Value { value: Some(v) }))
195            .collect();
196        self.client
197            .push(PushRequest {
198                entity_key:     entity_key.into(),
199                group:          group.into(),
200                feature_values: kv_pairs,
201            })
202            .await?
203            .into_inner();
204        Ok(())
205    }
206
207    /// Point-in-Time Join features against labeled entity rows through channels.
208    pub async fn channel_join(
209        &mut self,
210        join_features: Vec<String>,
211        existed_features: Vec<String>,
212        entity_rows: impl Stream<Item = EntityRow> + Send + 'static,
213    ) -> Result<(Vec<String>, impl Stream<Item = Result<Vec<Option<Value>>>>)> {
214        let mut join_features = Some(join_features);
215        let mut existed_features = Some(existed_features);
216        let inbound = stream! {
217            for await row in entity_rows {
218                let (join_features, existed_features) = match (join_features.take(), existed_features.take()) {
219                    (Some(join_features), Some(existed_features)) => (join_features, existed_features),
220                    _ => (Vec::new(), Vec::new()),
221                };
222                yield ChannelJoinRequest {
223                    join_features,
224                    existed_features,
225                    entity_row: Some(row),
226                };
227            }
228        };
229
230        let mut outbound = self.client.channel_join(Request::new(inbound)).await?.into_inner();
231
232        let ChannelJoinResponse { header, joined_row } = outbound
233            .message()
234            .await?
235            .ok_or_else(|| OomError::Unknown(String::from("stream finished with no response")))?;
236
237        let row = parse_raw_values(joined_row);
238
239        let outbound = async_stream::try_stream! {
240            yield row;
241            while let Some(ChannelJoinResponse { joined_row, .. }) = outbound.message().await? {
242                yield parse_raw_values(joined_row)
243            }
244        };
245        Ok((header, outbound))
246    }
247
248    /// Point-in-Time Join features against labeled entity rows through files.
249    pub async fn join(
250        &mut self,
251        features: Vec<String>,
252        input_file: impl AsRef<Path>,
253        output_file: impl AsRef<Path>,
254    ) -> Result<()> {
255        self.client
256            .join(JoinRequest {
257                features,
258                input_file: input_file.as_ref().display().to_string(),
259                output_file: output_file.as_ref().display().to_string(),
260            })
261            .await?;
262        Ok(())
263    }
264
265    /// Export certain features to a channel.
266    pub async fn channel_export(
267        &mut self,
268        features: Vec<String>,
269        unix_milli: u64,
270        limit: impl Into<Option<usize>>,
271    ) -> Result<(Vec<String>, impl Stream<Item = Result<Vec<Option<Value>>>>)> {
272        let unix_milli = unix_milli.try_into()?;
273        let limit = limit.into().map(|n| n.try_into()).transpose()?;
274        let mut outbound = self
275            .client
276            .channel_export(ChannelExportRequest { features, unix_milli, limit })
277            .await?
278            .into_inner();
279
280        let ChannelExportResponse { header, row } = outbound
281            .message()
282            .await?
283            .ok_or_else(|| OomError::Unknown(String::from("stream finished with no response")))?;
284
285        let row = parse_raw_values(row);
286        let outbound = async_stream::try_stream! {
287            yield row;
288            while let Some(ChannelExportResponse{row, ..}) = outbound.message().await? {
289                yield parse_raw_values(row)
290            }
291        };
292        Ok((header, outbound))
293    }
294
295    /// Export certain features to a file.
296    pub async fn export(
297        &mut self,
298        features: Vec<String>,
299        unix_milli: u64,
300        output_file: impl AsRef<Path>,
301        limit: impl Into<Option<usize>>,
302    ) -> Result<()> {
303        let unix_milli = unix_milli.try_into()?;
304        let limit = limit.into().map(|n| n.try_into()).transpose()?;
305        let output_file = output_file.as_ref().display().to_string();
306        self.client
307            .export(ExportRequest { features, unix_milli, output_file, limit })
308            .await?;
309        Ok(())
310    }
311
312    /// Take snapshot for a stream feature group in offline store.
313    pub async fn snapshot(&mut self, group: impl Into<String>) -> Result<()> {
314        self.client.snapshot(SnapshotRequest { group: group.into() }).await?;
315        Ok(())
316    }
317}