orign 0.2.3

A globally distributed container orchestrator
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
use chrono::{DateTime, Utc};

use orign::config::GlobalConfig;
use serde_json::Value;
use std::error::Error;

pub async fn get_deployment(id: Option<String>) -> Result<(), Box<dyn Error>> {
    use prettytable::{format, Cell, Row, Table};
    use reqwest::Client;
    use serde_json::Value;

    let client = Client::new();
    let config = GlobalConfig::read()?;
    let server = config.server.unwrap();
    let api_key = config.api_key.as_deref().ok_or("API key not set")?;
    let bearer_token = format!("Bearer {}", api_key);

    // Construct the URL - either for a specific model or all models
    let url = match &id {
        Some(model_id) => format!("{}/v1/deployments/{}", server, model_id.trim()),
        None => format!("{}/v1/deployments", server),
    };

    let response = client
        .get(&url)
        .header("Authorization", bearer_token)
        .send()
        .await?;

    // Check status first before consuming the response
    let status = response.status();

    if status.is_success() {
        let mut resp_json: Value = response.json().await?;

        // If ID is provided, print the full model details as YAML
        if id.is_some() {
            remove_null_values(&mut resp_json);

            let yaml = serde_yaml::to_string(&resp_json)?;
            println!("{}", yaml);
            return Ok(());
        }
        // println!("{}", resp_json);

        // Extract the list of models from the response
        let models = resp_json["model_deployments"]
            .as_array()
            .cloned()
            .unwrap_or_default();

        if models.is_empty() {
            println!("No models found");
            return Ok(());
        }

        let mut table = Table::new();

        // Add table headers
        table.add_row(Row::new(vec![
            Cell::new("ID"),
            Cell::new("MODEL"),
            Cell::new("PROVIDER"),
            Cell::new("KIND"),
            Cell::new("REPLICAS"),
            Cell::new("VRAM REQUEST"),
            Cell::new("GPU TYPE"),
            Cell::new("DATA TYPE"),
            Cell::new("STATUS"),
        ]));

        // Iterate over models and add them to the table
        for model in models {
            let id = model["id"].as_str().unwrap_or("");
            let provider = model["provider"].as_str().unwrap_or("");
            let kind = model["kind"].as_str().unwrap_or("");
            let vram_request = model["vram_request"].as_str().unwrap_or("");
            let gpu_type = model["gpu_type"].as_str().unwrap_or("");
            let status = model["status"].as_str().unwrap_or("");
            // Extract model name from the appropriate params based on provider
            let model_name = match provider {
                "vllm" => model["vllm_params"]["model"].as_str().unwrap_or(""),
                "sentence-tf" => model["sentence_tf_params"]["model"].as_str().unwrap_or(""),
                _ => "",
            };

            let dtype = match provider {
                "vllm" => model["vllm_params"]["dtype"].as_str().unwrap_or(""),
                _ => "",
            };

            // Format replicas as "available/desired"
            let replicas = format!(
                "{}/{}",
                model["replicas"]["available"].as_i64().unwrap_or(0),
                model["replicas"]["desired"].as_i64().unwrap_or(0)
            );

            table.add_row(Row::new(vec![
                Cell::new(id),
                Cell::new(model_name),
                Cell::new(provider),
                Cell::new(kind),
                Cell::new(&replicas),
                Cell::new(vram_request),
                Cell::new(gpu_type),
                Cell::new(dtype),
                Cell::new(status),
            ]));
        }

        table.set_format(*format::consts::FORMAT_CLEAN);
        // Print the table to stdout
        table.printstd();
    } else {
        let error_text = response.text().await?;
        eprintln!("Error: {}", error_text);
    }

    Ok(())
}

/// Prints training jobs from the configured orign server.
///
/// * With `Some(id)`, fetches `GET {server}/v1/trainings/{id}` (the id is trimmed),
///   decodes it as a `TrainingJob`, and prints it as YAML with all `null` object
///   fields removed.
/// * With `None`, fetches `GET {server}/v1/trainings` and prints the jobs as a
///   table, or `No trainings found` if there are none.
///
/// A non-success HTTP status prints the response body to stderr as `Error: ...`
/// and still returns `Ok(())`.
///
/// # Errors
/// Returns an error if the global config cannot be read, the server or API key is
/// not configured, the request fails, or the response body cannot be decoded.
pub async fn get_training(id: Option<String>) -> Result<(), Box<dyn Error>> {
    use crate::models::{TrainingJob, TrainingJobsResponse};
    use prettytable::{format, Cell, Row, Table};
    use reqwest::Client;

    let client = Client::new();
    let config = GlobalConfig::read()?;
    // Report a missing server as an error instead of panicking.
    let server = config.server.ok_or("Server not set")?;
    let api_key = config.api_key.as_deref().ok_or("API key not set")?;
    let bearer_token = format!("Bearer {}", api_key);

    let url = match &id {
        Some(training_id) => format!("{}/v1/trainings/{}", server, training_id.trim()),
        None => format!("{}/v1/trainings", server),
    };

    let response = client
        .get(&url)
        .header("Authorization", bearer_token)
        .send()
        .await?;

    if !response.status().is_success() {
        let error_text = response.text().await?;
        eprintln!("Error: {}", error_text);
        return Ok(());
    }

    if id.is_some() {
        // Decode into the typed model, then back to JSON so nulls can be stripped.
        let job: TrainingJob = response.json().await?;
        let mut as_value = serde_json::to_value(&job)?;
        remove_null_values(&mut as_value);

        let yaml = serde_yaml::to_string(&as_value)?;
        println!("{}", yaml);
        return Ok(());
    }

    let trainings: TrainingJobsResponse = response.json().await?;
    if trainings.jobs.is_empty() {
        println!("No trainings found");
        return Ok(());
    }

    let mut table = Table::new();

    table.add_row(Row::new(vec![
        Cell::new("ID"),
        Cell::new("NAME"),
        Cell::new("STATUS"),
        Cell::new("MODEL"),
        Cell::new("PROVIDER"),
        Cell::new("EPOCHS"),
        Cell::new("ADAPTER"),
        Cell::new("QUEUE"),
        Cell::new("CREATED"),
    ]));

    for training in trainings.jobs {
        let adapter = training.adapter.as_deref().unwrap_or("");
        let queue = training.queue.as_deref().unwrap_or("");

        // `created` is interpreted as whole seconds since the Unix epoch (UTC).
        let created = DateTime::<Utc>::from_timestamp(training.created, 0)
            .map(|dt| dt.format("%Y-%m-%d %H:%M:%S").to_string())
            .unwrap_or_else(|| "[Invalid Timestamp]".to_string());

        // The model name lives in the params struct for the job's framework.
        let model = match training.framework.as_str() {
            "ms-swift" => training
                .ms_swift_params
                .as_ref()
                .map(|params| params.model.clone())
                .unwrap_or_default(),
            "llama-factory" => training
                .llama_factory_params
                .as_ref()
                .map(|params| params.model.clone())
                .unwrap_or_default(),
            _ => String::new(),
        };

        // Only ms-swift params expose an epoch count; llama-factory shows "[N/A]".
        let epochs = match (&training.ms_swift_params, &training.llama_factory_params) {
            (Some(params), _) => params.num_train_epochs.to_string(),
            (None, Some(_)) => "[N/A]".to_string(),
            (None, None) => "-".to_string(),
        };

        table.add_row(Row::new(vec![
            Cell::new(&training.id),
            Cell::new(&training.name),
            Cell::new(&training.status),
            Cell::new(&model),
            Cell::new(&training.framework),
            Cell::new(&epochs),
            Cell::new(adapter),
            Cell::new(queue),
            Cell::new(&created),
        ]));
    }

    table.set_format(*format::consts::FORMAT_CLEAN);
    table.printstd();

    Ok(())
}

/// Prints replay buffers from the configured orign server.
///
/// * With `Some(name)`, `name` must have the form `namespace/name` (surrounding
///   whitespace is ignored, both parts must be non-empty). It fetches
///   `GET {server}/v1/buffers/{namespace}/{name}` and prints the response as YAML
///   with all `null` object fields removed.
/// * With `None`, fetches `GET {server}/v1/buffers` and prints the buffers as a
///   table, or `No buffers found` if there are none.
///
/// A non-success HTTP status prints the response body to stderr as `Error: ...`
/// and still returns `Ok(())`.
///
/// # Errors
/// Returns an error if `name` is not of the form `namespace/name`, the global
/// config cannot be read, the server or API key is not configured, the request
/// fails, or the response body cannot be decoded.
pub async fn get_buffer(name: Option<String>) -> Result<(), Box<dyn Error>> {
    use orign::resources::v1::buffers::models::V1ReplayBuffersResponse;
    use prettytable::{format, Cell, Row, Table};
    use reqwest::Client;

    let client = Client::new();
    let config = GlobalConfig::read()?;
    // Report a missing server as an error instead of panicking.
    let server = config.server.ok_or("Server not set")?;
    let api_key = config.api_key.as_deref().ok_or("API key not set")?;
    let bearer_token = format!("Bearer {}", api_key);

    let url = match &name {
        Some(buffer_id) => {
            // Require exactly one '/' separating two non-empty segments.
            let (namespace, buffer_name) = buffer_id
                .trim()
                .split_once('/')
                .filter(|(ns, n)| !ns.is_empty() && !n.is_empty() && !n.contains('/'))
                .ok_or("You must specify the name in the format 'namespace/name'.")?;
            format!("{}/v1/buffers/{}/{}", server, namespace, buffer_name)
        }
        None => format!("{}/v1/buffers", server),
    };

    let response = client
        .get(&url)
        .header("Authorization", bearer_token)
        .send()
        .await?;

    if !response.status().is_success() {
        let error_text = response.text().await?;
        eprintln!("Error: {}", error_text);
        return Ok(());
    }

    if name.is_some() {
        // Decoded as untyped JSON so every field the server returns is printed.
        let mut as_value: Value = response.json().await?;
        remove_null_values(&mut as_value);

        let yaml = serde_yaml::to_string(&as_value)?;
        println!("{}", yaml);
        return Ok(());
    }

    let replay_buffers: V1ReplayBuffersResponse = response.json().await?;
    if replay_buffers.buffers.is_empty() {
        println!("No buffers found");
        return Ok(());
    }

    let mut table = Table::new();

    // Headers must correspond one-to-one with the cells pushed for each row below,
    // otherwise values are printed under the wrong column.
    table.add_row(Row::new(vec![
        Cell::new("ID"),
        Cell::new("NAME"),
        Cell::new("NUM_EXAMPLES"),
        Cell::new("TRAIN_IDX"),
        Cell::new("SAMPLE_STRATEGY"),
        Cell::new("SAMPLE_N"),
        Cell::new("TRAIN_EVERY"),
    ]));

    for buffer in replay_buffers.buffers {
        let num_examples = buffer.status.num_records.unwrap_or(0);
        let train_idx = buffer.status.train_idx.unwrap_or(0);
        let train_every = buffer.train_every.unwrap_or(0);
        // Displayed as "namespace/name", the same form accepted by `name`.
        let qualified_name = format!(
            "{}/{}",
            buffer.metadata.namespace.as_str(),
            buffer.metadata.name.as_str()
        );

        table.add_row(Row::new(vec![
            Cell::new(buffer.metadata.id.as_str()),
            Cell::new(&qualified_name),
            Cell::new(&num_examples.to_string()),
            Cell::new(&train_idx.to_string()),
            Cell::new(buffer.sample_strategy.as_str()),
            Cell::new(&buffer.sample_n.to_string()),
            Cell::new(&train_every.to_string()),
        ]));
    }

    table.set_format(*format::consts::FORMAT_CLEAN);
    table.printstd();

    Ok(())
}

/// Fetches `GET {server}/v1/models` and prints the raw JSON response on one line.
///
/// A non-success HTTP status prints the response body to stderr as `Error: ...`
/// (matching the other `get_*` commands) and still returns `Ok(())`.
///
/// # Errors
/// Returns an error if the global config cannot be read, the server or API key is
/// not configured, the request fails, or the response body is not valid JSON.
pub async fn get_models() -> Result<(), Box<dyn Error>> {
    use reqwest::Client;

    let client = Client::new();
    let config = GlobalConfig::read()?;
    // Report a missing server as an error instead of panicking.
    let server = config.server.ok_or("Server not set")?;
    let api_key = config.api_key.as_deref().ok_or("API key not set")?;
    let bearer_token = format!("Bearer {}", api_key);

    let url = format!("{}/v1/models", server);

    let response = client
        .get(&url)
        .header("Authorization", bearer_token)
        .send()
        .await?;

    if response.status().is_success() {
        let resp_json: Value = response.json().await?;
        println!("{}", resp_json);
    } else {
        // Surface server-side failures instead of exiting with no output.
        let error_text = response.text().await?;
        eprintln!("Error: {}", error_text);
    }

    Ok(())
}

/// Fetches `GET {server}/v1/datasets` and prints the raw JSON response on one line.
///
/// A non-success HTTP status prints the response body to stderr as `Error: ...`
/// (matching the other `get_*` commands) and still returns `Ok(())`.
///
/// # Errors
/// Returns an error if the global config cannot be read, the server or API key is
/// not configured, the request fails, or the response body is not valid JSON.
pub async fn get_datasets() -> Result<(), Box<dyn Error>> {
    use reqwest::Client;

    let client = Client::new();
    let config = GlobalConfig::read()?;
    // Report a missing server as an error instead of panicking.
    let server = config.server.ok_or("Server not set")?;
    let api_key = config.api_key.as_deref().ok_or("API key not set")?;
    let bearer_token = format!("Bearer {}", api_key);

    let url = format!("{}/v1/datasets", server);

    let response = client
        .get(&url)
        .header("Authorization", bearer_token)
        .send()
        .await?;

    if response.status().is_success() {
        let resp_json: Value = response.json().await?;
        println!("{}", resp_json);
    } else {
        // Surface server-side failures instead of exiting with no output.
        let error_text = response.text().await?;
        eprintln!("Error: {}", error_text);
    }

    Ok(())
}

/// Fetches `GET {server}/v1/adapters` and prints the raw JSON response on one line.
///
/// A non-success HTTP status prints the response body to stderr as `Error: ...`
/// (matching the other `get_*` commands) and still returns `Ok(())`.
///
/// # Errors
/// Returns an error if the global config cannot be read, the server or API key is
/// not configured, the request fails, or the response body is not valid JSON.
pub async fn get_adapters() -> Result<(), Box<dyn Error>> {
    use reqwest::Client;

    let client = Client::new();
    let config = GlobalConfig::read()?;
    // Report a missing server as an error instead of panicking.
    let server = config.server.ok_or("Server not set")?;
    let api_key = config.api_key.as_deref().ok_or("API key not set")?;
    let bearer_token = format!("Bearer {}", api_key);

    let url = format!("{}/v1/adapters", server);

    let response = client
        .get(&url)
        .header("Authorization", bearer_token)
        .send()
        .await?;

    if response.status().is_success() {
        let resp_json: Value = response.json().await?;
        println!("{}", resp_json);
    } else {
        // Surface server-side failures instead of exiting with no output.
        let error_text = response.text().await?;
        eprintln!("Error: {}", error_text);
    }

    Ok(())
}

/// Recursively strips `null`-valued members from every JSON object in `value`.
///
/// Objects nested inside arrays are cleaned as well, but `null` *elements* of an
/// array are kept so element positions are preserved. Scalars are left untouched.
fn remove_null_values(value: &mut Value) {
    if let Value::Object(fields) = value {
        // Gather the offending keys first: the map cannot be mutated while iterating it.
        let null_keys: Vec<String> = fields
            .iter()
            .filter(|(_, field)| field.is_null())
            .map(|(key, _)| key.clone())
            .collect();

        for key in &null_keys {
            fields.remove(key);
        }

        fields.values_mut().for_each(remove_null_values);
    } else if let Value::Array(items) = value {
        items.iter_mut().for_each(remove_null_values);
    }
}