#ifndef H_WEIS_CURVE
#define H_WEIS_CURVE
#include "common.h"
#include "repr.h"
/// Classification of a curve by which of its coefficients vanish (or take a
/// special value), used to pick a doubling formula in CurvePoint::mul2.
/// The numeric values are spelled out; they equal the implicit ones.
enum CurveType
{
    Generic = 0,   ///< no special structure assumed
    AIsMinus3 = 1, ///< coefficient a == -3 (by name; not assigned anywhere in this header)
    AIsZero = 2,   ///< coefficient a == 0
    BIsZero = 3,   ///< coefficient b == 0 (by name; not assigned anywhere in this header)
};
/// Parameters of a short Weierstrass curve y^2 = x^3 + a*x + b whose
/// coefficients are elements of type `E`, together with the order of the
/// subgroup that points are checked against.
template <class E>
class WeierstrassCurve
{
    CurveType cty;                     // formula family selected in the constructor
    E a;                               // coefficient a
    E b;                               // coefficient b
    std::vector<u64> subgroup_order_;  // subgroup order as u64 words (word order not visible here)
    u8 order_len_;                     // NOTE(review): presumably the encoded length of the order -- confirm with callers

public:
    /// Stores the parameters and classifies the curve: AIsZero when `a` is
    /// zero, Generic otherwise. No other CurveType is ever assigned here.
    WeierstrassCurve(E a, E b, std::vector<u64> subgroup_order, u8 order_len)
        : a(a), b(b), subgroup_order_(subgroup_order), order_len_(order_len)
    {
        cty = a.is_zero() ? CurveType::AIsZero : CurveType::Generic;
    }

    /// Curve classification computed at construction time.
    CurveType ctype() const { return cty; }

    /// Coefficient a.
    E const &get_a() const { return a; }

    /// Coefficient a (same value as get_a()).
    E const &curve_a() const { return a; }

    /// Coefficient b.
    E const &get_b() const { return b; }

    /// Subgroup order words as passed to the constructor.
    std::vector<u64> const &subgroup_order() const { return subgroup_order_; }

    /// The `order_len` value passed to the constructor.
    u8 order_len() const { return order_len_; }
};
/// Point on a short Weierstrass curve, stored as three coordinates (x, y, z).
///
/// The arithmetic treats the triple as Jacobian projective coordinates:
/// normalize() recovers the affine point as (x / z^2, y / z^3), and any point
/// with z == 0 is the point at infinity ("zero"). The canonical zero produced
/// by this class is (0, 1, 0).
template <class E>
class CurvePoint
{
public:
    E x;
    E y;
    E z;

    /// Builds a point from explicit (x, y, z) coordinates, taken as-is.
    CurvePoint(E x, E y, E z) : x(x), y(y), z(z) {}

    /// Builds a point from affine coordinates (z is set to one).
    /// The pair (0, 0) is treated as the point at infinity and is stored in
    /// the same canonical form that zero() produces: (0, 1, 0).
    CurvePoint(E x, E y) : CurvePoint(x, y, x.one())
    {
        // The parameters shadow the members of the same name, so the members
        // must be reached through `this`; assigning to plain `y` would only
        // modify the local copy and leave the stored point as (0, 0, 0).
        if (this->x.is_zero() && this->y.is_zero())
        {
            // x is already zero in this branch.
            this->y = this->y.one();
            this->z = this->z.zero();
        }
    }

    /// Returns the canonical point at infinity (0, 1, 0) for the given context.
    template <class C>
    static CurvePoint<E> zero(C const &context)
    {
        return CurvePoint(E::zero(context), E::one(context), E::zero(context));
    }

    /// Member-wise copy assignment; returns *this so assignments can be chained.
    CurvePoint<E> &operator=(CurvePoint<E> const &other)
    {
        x = other.x;
        y = other.y;
        z = other.z;
        return *this;
    }

    /// Returns the affine coordinates (x, y) of this point.
    /// The point at infinity is reported as (0, 0). The point itself is not
    /// modified: normalization is done on a copy.
    std::tuple<E, E> xy() const
    {
        if (is_zero())
        {
            return std::tuple<E, E>(x.zero(), x.zero());
        }
        auto point = *this;
        point.normalize();
        return std::tuple<E, E>(point.x, point.y);
    }

    /// Checks that the point satisfies y^2 = x^3 + a*x + b for curve `wc`.
    /// The point at infinity is always accepted. The equation is an affine
    /// one, so a point with z != 1 is normalized (on a copy) before checking.
    bool check_on_curve(WeierstrassCurve<E> const &wc) const
    {
        if (is_zero())
        {
            return true;
        }
        if (!is_normalized())
        {
            auto affine = *this;
            affine.normalize();
            return affine.satisfies_affine_equation(wc);
        }
        return satisfies_affine_equation(wc);
    }

    /// Returns true when multiplying this point by the curve's subgroup order
    /// yields the point at infinity.
    template <class C>
    bool check_correct_subgroup(WeierstrassCurve<E> const &wc, C const &context) const
    {
        auto const p = mul(wc.subgroup_order(), wc, context);
        return p.is_zero();
    }

    /// Appends the affine x coordinate followed by the affine y coordinate to
    /// `data`, delegating the encoding of each to E::serialize(mod_byte_len, data).
    void serialize(u8 mod_byte_len, std::vector<u8> &data) const
    {
        auto const pair = xy();
        std::get<0>(pair).serialize(mod_byte_len, data);
        std::get<1>(pair).serialize(mod_byte_len, data);
    }

    /// True when this is the point at infinity (z == 0).
    bool is_zero() const
    {
        return z.is_zero();
    }

    /// True when the point is at infinity or already has z == 1.
    bool is_normalized() const
    {
        if (is_zero())
        {
            return true;
        }
        auto const one = z.one();
        return z == one;
    }

    /// Replaces the point with its additive inverse (y -> -y).
    /// The point at infinity is left untouched.
    void negate()
    {
        if (!is_zero())
        {
            y.negate();
        }
    }

    /// Doubles the point in place, using the formula that matches the curve type.
    /// Curve types other than Generic and AIsZero are reported via unimplemented().
    void mul2(WeierstrassCurve<E> const &wc)
    {
        switch (wc.ctype())
        {
        case CurveType::Generic:
            this->mul2_generic(wc);
            break;
        case CurveType::AIsZero:
            this->mul2_a_is_zero();
            break;
        default:
            unimplemented("only curve with A != 0 and B != 0 or just B != 0 are supported");
        }
    }

    /// Scalar multiplication by double-and-add; returns scalar * (*this).
    /// NOTE(review): RevBitIterator presumably yields the scalar's bits from
    /// the most significant end, with before() advancing it -- confirm in repr.h.
    template <class C>
    CurvePoint<E> mul(std::vector<u64> const &scalar, WeierstrassCurve<E> const &wc, C const &context) const
    {
        auto res = CurvePoint<E>::zero(context);
        auto found_one = false;
        for (auto it = RevBitIterator(scalar); it.before();)
        {
            auto i = *it;
            if (found_one)
            {
                // Every bit after the first set bit shifts the accumulator.
                res.mul2(wc);
            }
            else
            {
                // Leading zero bits are skipped without doubling.
                found_one = i;
            }
            if (i)
            {
                res.add(*this, wc, context);
            }
        }
        return res;
    }

    /// Adds `b` to this point in place.
    /// Handles infinity on either side, dispatches to add_mixed() when `b`
    /// has z == 1, doubles when both operands are the same point, and yields
    /// the point at infinity when they are inverses of each other. The general
    /// case matches the "add-2007-bl" Jacobian addition formulas.
    template <class C>
    void add(CurvePoint<E> const &b, WeierstrassCurve<E> const &wc, C const &context)
    {
        if (this->is_zero())
        {
            *this = b;
            return;
        }
        if (b.is_zero())
        {
            return;
        }
        if (b.z == E::one(context))
        {
            this->add_mixed(b, wc, context);
            return;
        }

        auto z1z1 = this->z;
        z1z1.square();
        auto z2z2 = b.z;
        z2z2.square();
        // u1/u2 and s1/s2 are the x and y coordinates of both operands
        // brought to a common denominator, so they can be compared directly.
        auto u1 = this->x;
        u1.mul(z2z2);
        auto u2 = b.x;
        u2.mul(z1z1);
        auto s1 = this->y;
        s1.mul(b.z);
        s1.mul(z2z2);
        auto s2 = b.y;
        s2.mul(this->z);
        s2.mul(z1z1);

        if (u1 == u2)
        {
            if (s1 == s2)
            {
                // Same point: the addition formula degenerates, double instead.
                this->mul2(wc);
            }
            else
            {
                // Same x, different y: the operands cancel out.
                *this = CurvePoint<E>::zero(context);
            }
            return;
        }

        // h = u2 - u1
        auto h = u2;
        h.sub(u1);
        // i = (2h)^2
        auto i = h;
        i.mul2();
        i.square();
        // j = h * i
        auto j = h;
        j.mul(i);
        // r = 2 * (s2 - s1)
        auto r = s2;
        r.sub(s1);
        r.mul2();
        // v = u1 * i
        auto v = u1;
        v.mul(i);

        // x3 = r^2 - j - 2v
        this->x = r;
        this->x.square();
        this->x.sub(j);
        this->x.sub(v);
        this->x.sub(v);

        // y3 = r * (v - x3) - 2 * s1 * j
        this->y = v;
        this->y.sub(this->x);
        this->y.mul(r);
        s1.mul(j);
        s1.mul2();
        this->y.sub(s1);

        // z3 = ((z1 + z2)^2 - z1z1 - z2z2) * h
        this->z.add(b.z);
        this->z.square();
        this->z.sub(z1z1);
        this->z.sub(z2z2);
        this->z.mul(h);
    }

    /// Adds `b` to this point in place, assuming `b` has z == 1.
    /// If `b` is not normalized the call is forwarded to add(). As in add(),
    /// equal operands are doubled and inverse operands give the canonical
    /// point at infinity. The general case matches the "madd-2007-bl" mixed
    /// addition formulas.
    template <class C>
    void add_mixed(CurvePoint<E> const &b, WeierstrassCurve<E> const &wc, C const &context)
    {
        if (b.is_zero())
        {
            return;
        }
        if (this->is_zero())
        {
            *this = b;
            return;
        }
        if (b.z != E::one(context))
        {
            this->add(b, wc, context);
            return;
        }

        auto z1z1 = this->z;
        z1z1.square();
        // u2/s2 are b's coordinates scaled to this point's denominator.
        auto u2 = b.x;
        u2.mul(z1z1);
        auto s2 = b.y;
        s2.mul(this->z);
        s2.mul(z1z1);

        if (this->x == u2)
        {
            if (this->y == s2)
            {
                // Same point: double instead of adding.
                this->mul2(wc);
            }
            else
            {
                // P + (-P): return the canonical zero, consistent with add(),
                // instead of relying on z collapsing to 0 with arbitrary x/y.
                *this = CurvePoint<E>::zero(context);
            }
            return;
        }

        // h = u2 - x1
        auto h = u2;
        h.sub(this->x);
        // hh = h^2
        auto hh = h;
        hh.square();
        // i = 4 * hh
        auto i = hh;
        i.mul2();
        i.mul2();
        // j = h * i
        auto j = h;
        j.mul(i);
        // r = 2 * (s2 - y1)
        auto r = s2;
        r.sub(this->y);
        r.mul2();
        // v = x1 * i
        auto v = this->x;
        v.mul(i);

        // x3 = r^2 - j - 2v
        this->x = r;
        this->x.square();
        this->x.sub(j);
        this->x.sub(v);
        this->x.sub(v);

        // j becomes 2 * y1 * j; must be computed before y is overwritten.
        j.mul(this->y);
        j.mul2();

        // y3 = r * (v - x3) - 2 * y1 * j
        this->y = v;
        this->y.sub(this->x);
        this->y.mul(r);
        this->y.sub(j);

        // z3 = (z1 + h)^2 - z1z1 - hh
        this->z.add(h);
        this->z.square();
        this->z.sub(z1z1);
        this->z.sub(hh);
    }

private:
    /// Evaluates y^2 == x^3 + a*x + b on the stored x and y, ignoring z.
    /// Only meaningful when the point is normalized (z == 1).
    bool satisfies_affine_equation(WeierstrassCurve<E> const &wc) const
    {
        auto lhs = y;
        lhs.square();

        // rhs starts as a copy of b, then accumulates a*x and x^3.
        auto rhs = wc.get_b();
        auto ax = x;
        ax.mul(wc.get_a());
        rhs.add(ax);
        auto x_3 = x;
        x_3.square();
        x_3.mul(x);
        rhs.add(x_3);

        return lhs == rhs;
    }

    /// Rescales the point in place so that z == 1 (x /= z^2, y /= z^3).
    /// The point at infinity and already-normalized points are left unchanged.
    void normalize()
    {
        if (is_zero())
        {
            return;
        }
        auto const one = x.one();
        if (z == one)
        {
            return;
        }
        // z is non-zero here; the zero fallback only guards the optional.
        E const z_inv = z.inverse().value_or(x.zero());
        auto zinv_powered = z_inv;
        zinv_powered.square(); // z^-2
        x.mul(zinv_powered);
        zinv_powered.mul(z_inv); // z^-3
        y.mul(zinv_powered);
        z = one;
    }

    /// In-place doubling for an arbitrary coefficient a; matches the
    /// "dbl-2007-bl" Jacobian doubling formulas.
    void mul2_generic(WeierstrassCurve<E> const &wc)
    {
        if (this->is_zero())
        {
            return;
        }
        auto xx = x;
        xx.square();
        auto yy = y;
        yy.square();
        auto yyyy = yy;
        yyyy.square();
        auto zz = z;
        zz.square();

        // s = 2 * ((x + yy)^2 - xx - yyyy)
        auto s = x;
        s.add(yy);
        s.square();
        s.sub(xx);
        s.sub(yyyy);
        s.mul2();

        // m = 3 * xx + a * zz^2
        auto m = xx;
        m.mul2();
        m.add(xx);
        auto a_zz2 = zz;
        a_zz2.square();
        a_zz2.mul(wc.curve_a());
        m.add(a_zz2);

        // t = m^2 - 2s
        auto two_s = s;
        two_s.mul2();
        auto t = m;
        t.square();
        t.sub(two_s);

        this->x = t;

        // z3 = (y + z)^2 - yy - zz; uses the old y, so it precedes the y update.
        this->z.add(this->y);
        this->z.square();
        this->z.sub(yy);
        this->z.sub(zz);

        // y3 = m * (s - x3) - 8 * yyyy
        this->y = s;
        this->y.sub(this->x);
        this->y.mul(m);
        yyyy.mul2();
        yyyy.mul2();
        yyyy.mul2();
        this->y.sub(yyyy);
    }

    /// In-place doubling specialised for a == 0; matches the "dbl-2009-l"
    /// Jacobian doubling formulas.
    void mul2_a_is_zero()
    {
        if (this->is_zero())
        {
            return;
        }
        auto a = this->x;
        a.square(); // a = x^2
        auto b = this->y;
        b.square(); // b = y^2
        auto c = b;
        c.square(); // c = y^4

        // d = 2 * ((x + b)^2 - a - c)
        auto d = this->x;
        d.add(b);
        d.square();
        d.sub(a);
        d.sub(c);
        d.mul2();

        // e = 3a, f = e^2
        auto e = a;
        e.mul2();
        e.add(a);
        auto f = e;
        f.square();

        // z3 = 2 * y * z; uses the old y, so it precedes the y update.
        this->z.mul(this->y);
        this->z.mul2();

        // x3 = f - 2d
        this->x = f;
        this->x.sub(d);
        this->x.sub(d);

        // y3 = e * (d - x3) - 8c
        this->y = d;
        this->y.sub(this->x);
        this->y.mul(e);
        c.mul2();
        c.mul2();
        c.mul2();
        this->y.sub(c);
    }
};
#endif