1use crate::{
2 error::OomError,
3 oomagent,
4 oomagent::{
5 oom_agent_client::OomAgentClient,
6 ChannelExportRequest,
7 ChannelExportResponse,
8 ChannelImportRequest,
9 ChannelJoinRequest,
10 ChannelJoinResponse,
11 ExportRequest,
12 FeatureValueMap,
13 HealthCheckRequest,
14 ImportRequest,
15 JoinRequest,
16 OnlineGetRequest,
17 OnlineMultiGetRequest,
18 PushRequest,
19 SnapshotRequest,
20 SyncRequest,
21 },
22 server::ServerWrapper,
23 util::{parse_raw_feature_values, parse_raw_values},
24 EntityRow,
25 Result,
26 Value,
27};
28use async_stream::stream;
29use futures_core::stream::Stream;
30use std::{collections::HashMap, path::Path, sync::Arc};
31use tonic::{codegen::StdError, transport, Request};
32
33#[derive(Debug, Clone)]
36pub struct Client {
37 client: OomAgentClient<transport::Channel>,
38 _agent: Option<Arc<ServerWrapper>>,
39}
40
41impl Client {
43 pub async fn connect<D>(dst: D) -> Result<Self>
45 where
46 D: std::convert::TryInto<tonic::transport::Endpoint>,
47 D::Error: Into<StdError>,
48 {
49 Ok(Self { client: OomAgentClient::connect(dst).await?, _agent: None })
50 }
51
52 pub async fn with_embedded_oomagent<P1, P2>(bin_path: Option<P1>, cfg_path: Option<P2>) -> Result<Self>
54 where
55 P1: AsRef<Path>,
56 P2: AsRef<Path>,
57 {
58 let agent = ServerWrapper::new(bin_path, cfg_path, None).await?;
59 Ok(Self {
60 client: OomAgentClient::connect(format!("http://{}", agent.address())).await?,
61 _agent: Some(Arc::new(agent)),
62 })
63 }
64
65 pub async fn with_default_embedded_oomagent() -> Result<Self> {
67 Self::with_embedded_oomagent(None::<String>, None::<String>).await
68 }
69
70 pub async fn health_check(&mut self) -> Result<()> {
72 Ok(self.client.health_check(HealthCheckRequest {}).await.map(|_| ())?)
73 }
74
75 pub async fn online_get_raw(
77 &mut self,
78 entity_key: impl Into<String>,
79 features: Vec<String>,
80 ) -> Result<FeatureValueMap> {
81 let res = self
82 .client
83 .online_get(OnlineGetRequest { entity_key: entity_key.into(), features })
84 .await?
85 .into_inner();
86 Ok(match res.result {
87 Some(res) => res,
88 None => FeatureValueMap::default(),
89 })
90 }
91
92 pub async fn online_get(
94 &mut self,
95 key: impl Into<String>,
96 features: Vec<String>,
97 ) -> Result<HashMap<String, Option<Value>>> {
98 let rs = self.online_get_raw(key, features).await?;
99 Ok(parse_raw_feature_values(rs))
100 }
101
102 pub async fn online_multi_get_raw(
104 &mut self,
105 entity_keys: Vec<String>,
106 features: Vec<String>,
107 ) -> Result<HashMap<String, FeatureValueMap>> {
108 let res = self
109 .client
110 .online_multi_get(OnlineMultiGetRequest { entity_keys, features })
111 .await?
112 .into_inner();
113 Ok(res.result)
114 }
115
116 pub async fn online_multi_get(
118 &mut self,
119 keys: Vec<String>,
120 features: Vec<String>,
121 ) -> Result<HashMap<String, HashMap<String, Option<Value>>>> {
122 let rs = self.online_multi_get_raw(keys, features).await?;
123 Ok(rs.into_iter().map(|(k, v)| (k, parse_raw_feature_values(v))).collect())
124 }
125
126 pub async fn sync(
128 &mut self,
129 group: impl Into<String>,
130 revision_id: impl Into<Option<u32>>,
131 purge_delay: u32,
132 ) -> Result<()> {
133 let group = group.into();
134 let revision_id = revision_id.into().map(i32::try_from).transpose()?;
135 let purge_delay = i32::try_from(purge_delay)?;
136 self.client
137 .sync(SyncRequest { revision_id, group, purge_delay })
138 .await?;
139 Ok(())
140 }
141
142 pub async fn channel_import(
144 &mut self,
145 group: impl Into<String>,
146 revision: impl Into<Option<i64>>,
147 description: impl Into<Option<String>>,
148 rows: impl Stream<Item = Vec<u8>> + Send + 'static,
149 ) -> Result<u32> {
150 let mut group = Some(group.into());
151 let mut description = description.into();
152 let mut revision = revision.into();
153 let inbound = stream! {
154 for await row in rows {
155 yield ChannelImportRequest{group: group.take(), description: description.take(), revision: revision.take(), row};
156 }
157 };
158 let res = self.client.channel_import(Request::new(inbound)).await?.into_inner();
159 Ok(res.revision_id as u32)
160 }
161
162 pub async fn import(
164 &mut self,
165 group: impl Into<String>,
166 revision: impl Into<Option<i64>>,
167 description: impl Into<Option<String>>,
168 input_file: impl AsRef<Path>,
169 delimiter: impl Into<Option<char>>,
170 ) -> Result<u32> {
171 let res = self
172 .client
173 .import(ImportRequest {
174 group: group.into(),
175 description: description.into(),
176 revision: revision.into(),
177 input_file: input_file.as_ref().display().to_string(),
178 delimiter: delimiter.into().map(String::from),
179 })
180 .await?
181 .into_inner();
182 Ok(res.revision_id as u32)
183 }
184
185 pub async fn push(
187 &mut self,
188 entity_key: impl Into<String>,
189 group: impl Into<String>,
190 kv_pairs: HashMap<String, Value>,
191 ) -> Result<()> {
192 let kv_pairs = kv_pairs
193 .into_iter()
194 .map(|(k, v)| (k, oomagent::Value { value: Some(v) }))
195 .collect();
196 self.client
197 .push(PushRequest {
198 entity_key: entity_key.into(),
199 group: group.into(),
200 feature_values: kv_pairs,
201 })
202 .await?
203 .into_inner();
204 Ok(())
205 }
206
207 pub async fn channel_join(
209 &mut self,
210 join_features: Vec<String>,
211 existed_features: Vec<String>,
212 entity_rows: impl Stream<Item = EntityRow> + Send + 'static,
213 ) -> Result<(Vec<String>, impl Stream<Item = Result<Vec<Option<Value>>>>)> {
214 let mut join_features = Some(join_features);
215 let mut existed_features = Some(existed_features);
216 let inbound = stream! {
217 for await row in entity_rows {
218 let (join_features, existed_features) = match (join_features.take(), existed_features.take()) {
219 (Some(join_features), Some(existed_features)) => (join_features, existed_features),
220 _ => (Vec::new(), Vec::new()),
221 };
222 yield ChannelJoinRequest {
223 join_features,
224 existed_features,
225 entity_row: Some(row),
226 };
227 }
228 };
229
230 let mut outbound = self.client.channel_join(Request::new(inbound)).await?.into_inner();
231
232 let ChannelJoinResponse { header, joined_row } = outbound
233 .message()
234 .await?
235 .ok_or_else(|| OomError::Unknown(String::from("stream finished with no response")))?;
236
237 let row = parse_raw_values(joined_row);
238
239 let outbound = async_stream::try_stream! {
240 yield row;
241 while let Some(ChannelJoinResponse { joined_row, .. }) = outbound.message().await? {
242 yield parse_raw_values(joined_row)
243 }
244 };
245 Ok((header, outbound))
246 }
247
248 pub async fn join(
250 &mut self,
251 features: Vec<String>,
252 input_file: impl AsRef<Path>,
253 output_file: impl AsRef<Path>,
254 ) -> Result<()> {
255 self.client
256 .join(JoinRequest {
257 features,
258 input_file: input_file.as_ref().display().to_string(),
259 output_file: output_file.as_ref().display().to_string(),
260 })
261 .await?;
262 Ok(())
263 }
264
265 pub async fn channel_export(
267 &mut self,
268 features: Vec<String>,
269 unix_milli: u64,
270 limit: impl Into<Option<usize>>,
271 ) -> Result<(Vec<String>, impl Stream<Item = Result<Vec<Option<Value>>>>)> {
272 let unix_milli = unix_milli.try_into()?;
273 let limit = limit.into().map(|n| n.try_into()).transpose()?;
274 let mut outbound = self
275 .client
276 .channel_export(ChannelExportRequest { features, unix_milli, limit })
277 .await?
278 .into_inner();
279
280 let ChannelExportResponse { header, row } = outbound
281 .message()
282 .await?
283 .ok_or_else(|| OomError::Unknown(String::from("stream finished with no response")))?;
284
285 let row = parse_raw_values(row);
286 let outbound = async_stream::try_stream! {
287 yield row;
288 while let Some(ChannelExportResponse{row, ..}) = outbound.message().await? {
289 yield parse_raw_values(row)
290 }
291 };
292 Ok((header, outbound))
293 }
294
295 pub async fn export(
297 &mut self,
298 features: Vec<String>,
299 unix_milli: u64,
300 output_file: impl AsRef<Path>,
301 limit: impl Into<Option<usize>>,
302 ) -> Result<()> {
303 let unix_milli = unix_milli.try_into()?;
304 let limit = limit.into().map(|n| n.try_into()).transpose()?;
305 let output_file = output_file.as_ref().display().to_string();
306 self.client
307 .export(ExportRequest { features, unix_milli, output_file, limit })
308 .await?;
309 Ok(())
310 }
311
312 pub async fn snapshot(&mut self, group: impl Into<String>) -> Result<()> {
314 self.client.snapshot(SnapshotRequest { group: group.into() }).await?;
315 Ok(())
316 }
317}