aleph_alpha_client/
explanation.rs1use serde::{Deserialize, Serialize};
2
3use crate::{Prompt, Task};
4
5pub struct TaskExplanation<'a> {
7 pub prompt: Prompt<'a>,
9 pub target: &'a str,
12 pub granularity: Granularity,
14}
15
16#[derive(Default)]
18pub struct Granularity {
19 prompt: PromptGranularity,
22}
23
24impl Granularity {
25 pub fn with_prompt_granularity(self, prompt_granularity: PromptGranularity) -> Self {
28 Self {
29 prompt: prompt_granularity,
30 }
31 }
32}
33
34#[derive(Serialize, Copy, Clone, PartialEq, Default)]
41#[serde(tag = "type", rename_all = "snake_case")]
42pub enum PromptGranularity {
43 #[default]
46 Auto,
47 Word,
48 Sentence,
49 Paragraph,
50}
51
52impl PromptGranularity {
53 fn is_auto(&self) -> bool {
54 self == &PromptGranularity::Auto
55 }
56}
57
58#[derive(Serialize)]
60struct BodyExplanation<'a> {
61 prompt: Prompt<'a>,
62 target: &'a str,
63 #[serde(skip_serializing_if = "PromptGranularity::is_auto")]
64 prompt_granularity: PromptGranularity,
65 model: &'a str,
66}
67
68#[derive(Deserialize, Debug, PartialEq)]
70pub struct ResponseExplanation {
71 explanations: Vec<Explanation>,
74}
75
76#[derive(Debug, PartialEq)]
78pub struct ExplanationOutput {
79 pub items: Vec<ItemExplanation>,
81}
82
83impl ExplanationOutput {
84 fn from(mut response: ResponseExplanation) -> ExplanationOutput {
85 ExplanationOutput {
86 items: response.explanations.pop().unwrap().items,
87 }
88 }
89}
90
91#[derive(Debug, Deserialize, PartialEq)]
93pub struct Explanation {
94 pub items: Vec<ItemExplanation>,
96}
97
98#[derive(PartialEq, Deserialize, Debug)]
103#[serde(tag = "type", rename_all = "snake_case")]
104pub enum ItemExplanation {
105 Text { scores: Vec<TextScore> },
106 Image { scores: Vec<ImageScore> },
107 Target { scores: Vec<TextScore> },
108}
109
110#[derive(Debug, PartialEq, Deserialize, Clone)]
112pub struct TextScore {
113 pub start: u32,
114 pub length: u32,
115 pub score: f32,
116}
117
118#[derive(Deserialize)]
120struct BoundingBox {
121 top: f32,
122 left: f32,
123 width: f32,
124 height: f32,
125}
126
127#[derive(Deserialize)]
129struct ImageScoreWithRect {
130 rect: BoundingBox,
131 score: f32,
132}
133
134#[derive(Debug, PartialEq, Deserialize, Clone)]
136#[serde(from = "ImageScoreWithRect")]
137pub struct ImageScore {
138 pub left: f32,
139 pub top: f32,
140 pub width: f32,
141 pub height: f32,
142 pub score: f32,
143}
144
145impl From<ImageScoreWithRect> for ImageScore {
146 fn from(value: ImageScoreWithRect) -> Self {
147 Self {
148 left: value.rect.left,
149 top: value.rect.top,
150 width: value.rect.width,
151 height: value.rect.height,
152 score: value.score,
153 }
154 }
155}
156
157impl Task for TaskExplanation<'_> {
158 type Output = ExplanationOutput;
159
160 type ResponseBody = ResponseExplanation;
161
162 fn build_request(
163 &self,
164 client: &reqwest::Client,
165 base: &str,
166 model: &str,
167 ) -> reqwest::RequestBuilder {
168 let body = BodyExplanation {
169 model,
170 prompt: self.prompt.borrow(),
171 target: self.target,
172 prompt_granularity: self.granularity.prompt,
173 };
174 client.post(format!("{base}/explain")).json(&body)
175 }
176
177 fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
178 ExplanationOutput::from(response)
179 }
180}